Compare commits

..

53 Commits

Author SHA1 Message Date
NanoCode012
fbf3ca86c9 feat: add support for qwen25 vl for multimodal 2025-02-18 12:42:29 +07:00
Sunny
2de866e92f revert seq len to 8192 2024-12-08 22:30:20 -05:00
Sunny
295e07dcca settings 2024-12-08 22:22:18 -05:00
bursteratom
3c07b6d6b1 lint 2024-12-06 16:06:57 -05:00
bursteratom
89dae7dc6d lora_target_module 2024-12-06 15:41:09 -05:00
bursteratom
1b54af8e54 lora config 2024-12-06 15:27:18 -05:00
bursteratom
ca7b56cba3 lora config 2024-12-06 15:26:06 -05:00
bursteratom
ea8269d2eb lora config 2024-12-06 15:23:24 -05:00
bursteratom
13ca7ed087 comment out lora target 2024-12-06 15:21:08 -05:00
bursteratom
0dfd8541ee lora config qwen2vl 2024-12-06 14:56:51 -05:00
bursteratom
75e1d3537f qwen2_vl get_text_config 2024-12-06 14:54:06 -05:00
bursteratom
2b7f3bd6ab qwen2_vl get_text_config 2024-12-06 14:52:17 -05:00
bursteratom
d85a229afe get_text_config 2024-12-06 14:50:05 -05:00
bursteratom
355cd7c872 update is_multimodal requirement to include qwen2_vl 2024-12-06 14:43:50 -05:00
bursteratom
eab1638686 lint 2024-12-06 14:37:32 -05:00
bursteratom
a3a4d22709 config init qwen2-vl chat template 2024-12-06 14:24:03 -05:00
bursteratom
f9eb7d8663 qwen2 example 2024-12-06 14:22:08 -05:00
bursteratom
343771a6d3 lint 2024-12-06 13:15:49 -05:00
bursteratom
d2c32d0cba lint 2024-12-06 13:04:42 -05:00
bursteratom
cec9887609 add llava chat template to config 2024-12-06 12:57:20 -05:00
bursteratom
88b2cae748 llava template 2024-12-06 12:54:43 -05:00
bursteratom
aea2565938 for test only 2024-12-06 11:54:07 -05:00
bursteratom
1ad56303b2 lint 2024-12-05 15:34:04 -05:00
bursteratom
dc055a4ef7 lint 2024-12-05 14:59:51 -05:00
bursteratom
169116a50f llava example 2024-12-05 12:58:30 -05:00
bursteratom
43e412f660 comment 2024-12-04 13:18:25 -05:00
Wing Lian
7aa57803e1 fix optimizer reset for relora sft (#1414)
* fix optimizer reset

* set states to reset for 8bit optimizers and handle quantile runtime error for embeddings

* fix relora test to check grad_norm

* use flash attn for relora and tweak hyperparams for test

* fix messages field for test dataset
2024-12-04 12:33:29 -05:00
NanoCode012
1969fa3bf0 fix(readme): update cuda instructions during preprocess (#2114) [skip ci] 2024-12-04 12:33:29 -05:00
NanoCode012
4078f37076 feat: add cut_cross_entropy (#2091)
* feat: add cut_cross_entropy

* fix: add to input

* fix: remove from setup.py

* feat: refactor into an integration

* chore: ignore lint

* feat: add test for cce

* fix: set max_steps for liger test

* chore: Update base model following suggestion

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* chore: update special_tokens following suggestion

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* chore: remove with_temp_dir following comments

* fix: plugins aren't loaded

* chore: update quotes in error message

* chore: lint

* chore: lint

* feat: enable FA on test

* chore: refactor get_pytorch_version

* fix: lock cce commit version

* fix: remove subclassing UT

* fix: downcast even if not using FA and config check

* feat: add test to check different attentions

* feat: add install to CI

* chore: refactor to use parametrize for attention

* fix: pytest not detecting test

* feat: handle torch lower than 2.4

* fix args/kwargs to match docs

* use release version cut-cross-entropy==24.11.4

* fix quotes

* fix: use named params for clarity for modal builder

* fix: handle install from pip

* fix: test check only top level module install

* fix: re-add import check

* uninstall existing version if no transformers submodule in cce

* more dataset fixtures into the cache

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-04 12:33:29 -05:00
Wing Lian
f073af6d99 fix merge conflict of duplicate max_steps in config for relora (#2116) 2024-12-04 12:33:29 -05:00
Wing Lian
139d2612fa fix so inference can be run against quantized models without adapters (#1834)
* fix so inference can be run against quantized models without adapters

* Update error msg [skip e2e]

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-04 12:33:29 -05:00
Sunny Liu
20573fd13e Add ds model card, rebased (#2101) [skip ci]
* rebased add_ds_model_card

* manual rebasing

* fix redundancy

* lint

* include case when ds_tag is none

* conform to kwargs in create_model_card
2024-12-04 12:33:29 -05:00
NanoCode012
2b7b4af81c fix(vlm): handle legacy conversation data format and check image in data (#2018) [skip ci]
* fix: handle legacy conversation data format and check image in data

* feat: add test for llama vision

* feat: add max_steps to test

* fix: incorrect indent and return preprocess

* feat: use smaller model and dataset

* chore: add extra config for sharegpt dataset
2024-12-04 12:33:29 -05:00
Sunny Liu
d56260c8d5 Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accomodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-12-04 12:33:29 -05:00
Wing Lian
cac785ec0e use pytest sugar and verbose for more info during ci (#2112) [skip ci]
* use pytest sugar and verbose for more info during ci

* also run test suite when test requirements or cicd.sh changes

* also on PR too
2024-12-04 12:33:29 -05:00
Wing Lian
e62991edef make the eval size smaller for the resume test (#2111) [skip ci] 2024-12-04 12:33:29 -05:00
Wing Lian
fd9e7b55f6 build causal_conv1d and mamba-ssm into the base image (#2113)
* build causal_conv1d and mamba-ssm into the base image

* also build base images on changes to Dockerfile-base and base workflow yaml
2024-12-04 12:33:29 -05:00
Wing Lian
c0c53eb62f various tests fixes for flakey tests (#2110)
* add mhenrichsen/alpaca_2k_test with revision dataset download fixture for flaky tests

* log slowest tests

* pin pynvml==11.5.3

* fix load local hub path

* optimize for speed w smaller models and val_set_size

* replace pynvml

* make the resume from checkpoint e2e faster

* make tests smaller
2024-12-04 12:33:29 -05:00
Oliver Molenschot
b0fbd4d11d Add Exact Deduplication Feature to Preprocessing Pipeline (#2072)
* Add example YAML file for training Mistral using DPO

* added deduplication code

* Add exact deduplication feature and update examples

* Improve deduplication for train/eval overlap

Changed the deduplication function to use a more memory-efficient hashing method. Applied Git suggestions to improve clarity and maintainability.\n\nThe deduplication now handles cases where train and eval datasets have overlapping elements.

* Improve deduplication for train/eval overlap

Changed the deduplication function to use a more memory-efficient hashing method. Applied Git suggestions to improve clarity and maintainability.\n\nThe deduplication now handles cases where train and eval datasets have overlapping elements.

* Apply suggestions from code review

To handle the original case where we do not do deduplication

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Improve false collision detection to ensure dataset integrity

- Added test cases to simulate and verify handling of forced hash collisions between datasets.
- Ensured that datasets with identical hashes but different content are correctly identified, preventing incorrect deduplication.
- Updated unit tests to include scenarios where collisions occur across both training and evaluation datasets, as well as within a single dataset.

* Moved the constants file to the tests folder

- Relocated `constants.py` to the `tests` folder to improve modularity and maintain a clear separation between source and test files.
- Renamed `cicd/tests.py` to `cicd/cicd_tests.py` to resolve a conflict with `tests/__init__.py`, which caused Mypy to fail due to duplicate module names.
- Updated all references to `cicd.tests` in the codebase to `cicd.cicd_tests` to reflect the renaming and ensure compatibility.
- These changes ensure Mypy passes the pre-commit hook and maintain alignment with the project's structure.

* revert some changes from previous commit and fix relative import

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-04 12:33:29 -05:00
Wing Lian
1a70d4d6a4 add e2e tests for Unsloth qlora and test the builds (#2093)
* see if unsloth installs cleanly in ci

* check unsloth install on regular tests, not sdist

* fix ampere check exception for ci

* use cached_property instead

* add an e2e test for unsloth qlora

* reduce seq len and mbsz to prevent oom in ci

* add checks for fp16 and sdp_attention

* pin unsloth to a specific release

* add unsloth to docker image too

* fix flash attn xentropy patch

* fix loss, add check for loss when using fa_xentropy

* fix special tokens for test

* typo

* test fa xentropy with and without gradient accum

* pr feedback changes
2024-12-04 12:33:29 -05:00
Wing Lian
d8787a433f support seperate lr for embeddings, similar to loraplus (#1910) [skip ci]
* support seperate lr for embeddings, similar to loraplus

* add test case for train w lr embedding scale

* use kwarg for optimizer

* make sure to handle the optimizer creation

* make sure to handle for embedding_lr too

* use smollm for e2e, check for embeddings lr first before wdecay
2024-12-04 12:33:29 -05:00
NanoCode012
e775422269 fix: ds3 and fsdp lmbench eval (#2102) [ski[p ci]
* fix: ds3 and fsdp lmbench eval

* chore: update comment

* fix: test signature
2024-12-04 12:33:29 -05:00
Wing Lian
97178f5960 add finetome dataset to fixtures, check eval_loss in test (#2106) [skip ci]
* add finetome dataset to fixtures, check eval_loss in test

* add qwen 0.5b to pytest session fixture
2024-12-04 12:33:29 -05:00
bursteratom
4698eed43f set pixtral chat template 2024-12-04 12:11:21 -05:00
bursteratom
f84c3b37e7 lint 2024-12-04 11:59:45 -05:00
bursteratom
c39971c659 stuff 2024-11-27 10:52:36 -05:00
bursteratom
33a178c788 val config pixtral chat template 2024-11-27 10:36:23 -05:00
bursteratom
db15605e7e pixral chat template 2024-11-27 10:34:19 -05:00
bursteratom
9e112bc8b5 lint 2024-11-27 10:33:35 -05:00
bursteratom
e038410778 lint 2024-11-27 10:24:37 -05:00
bursteratom
f4385c3cf4 add special tokens 2024-11-27 10:18:45 -05:00
bursteratom
d58c772df6 pixtral flash-attn false 2024-11-27 10:16:17 -05:00
bursteratom
69265a53b5 stuff 2024-11-27 09:53:41 -05:00
69 changed files with 752 additions and 2276 deletions

View File

@@ -23,15 +23,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20
steps:
@@ -61,18 +55,12 @@ jobs:
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |

View File

@@ -10,7 +10,6 @@ on:
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
pull_request:
paths:
- '**.py'
@@ -18,7 +17,6 @@ on:
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
@@ -45,15 +43,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20
steps:
@@ -83,14 +75,9 @@ jobs:
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest -n8 --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
@@ -101,7 +88,6 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
@@ -131,18 +117,11 @@ jobs:
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest -n8 --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |

235
README.md
View File

@@ -41,12 +41,9 @@ Features:
## Table of Contents
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Axolotl supports](#axolotl-supports)
- [Quickstart ⚡](#quickstart-)
- [Usage](#usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [Axolotl supports](#axolotl-supports)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
@@ -78,6 +75,14 @@ Features:
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
- [Badge ❤🏷️](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
</td>
<td>
@@ -100,11 +105,36 @@ Features:
</tr>
</table>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -135,117 +165,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
### Axolotl CLI
If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
# finetune lora
axolotl train examples/openllama-3b/lora.yml
# inference
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out"
# gradio
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.
```bash
# Fetch example YAML files (stores in "examples/" folder)
axolotl fetch examples
# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs
# Optionally, specify a destination folder
axolotl fetch examples --dest path/to/folder
```
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Sponsors 🤝❤
If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
---
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
---
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Advanced Setup
### Environment
@@ -763,6 +682,86 @@ See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, al
## Need help? 🙋
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
Need dedicated support? Please contact us at [wing@axolotl.ai](ailto:wing@axolotl.ai) for dedicated support options.
Need dedicated support? Please contact us at [✉wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Community Showcase
Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
Open Access AI Collective
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b-fixed)
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
PocketDoc Labs
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
[NanoCode012](https://github.com/NanoCode012), [tmm1](https://github.com/tmm1),
[mhenrichsen](https://github.com/mhenrichsen), [casper-hansen](https://github.com/casper-hansen),
[hamelsmu](https://github.com/hamelsmu) and many more who help us accelerate forward by fixing bugs, answering
community questions and implementing new features. Axolotl needs donations from sponsors for the compute needed to
run our unit & integration tests, troubleshooting community issues, and providing bounties. If you love axolotl,
consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsors/OpenAccess-AI-Collective),
[Ko-fi](https://ko-fi.com/axolotl_ai) or reach out directly to
[wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org).
---
#### 💎 Diamond Sponsors - [Contact directly](mailto:wing@openaccessaicollective.org)
---
#### 🥇 Gold Sponsors - $5000/mo
---
#### 🥈 Silver Sponsors - $1000/mo
---
#### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
---

View File

@@ -4,6 +4,7 @@ ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"

View File

@@ -1,7 +1,6 @@
#!/bin/bash
set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -5,6 +5,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION

View File

@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -2,7 +2,7 @@ ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -5,6 +5,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ARG GITHUB_REF="main"

View File

@@ -0,0 +1,63 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llava
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -0,0 +1,65 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: pixtral
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,63 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen2_vl
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,19 +0,0 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"
[project]
name = "axolotl"
dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
[project.scripts]
axolotl = "axolotl.cli.main:main"
[project.urls]
Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools_scm]

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.14.0
transformers>=4.46.3
peft==0.13.2
transformers==4.46.3
tokenizers>=0.20.1
bitsandbytes==0.45.0
accelerate==1.2.0
bitsandbytes==0.44.1
accelerate==1.1.0
datasets==3.1.0
deepspeed==0.15.4
pydantic==2.6.3
@@ -31,7 +31,7 @@ art
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post3
autoawq==0.2.7.post2
triton>=2.3.0
liger-kernel==0.4.2
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.12.1
trl==0.12.0
zstandard==0.22.0
fastcore

View File

@@ -16,11 +16,11 @@ if v < V("2.4.0"):
sys.exit(0)
cce_spec = importlib.util.find_spec("cut_cross_entropy")
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
UNINSTALL_PREFIX = ""
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
if cce_spec and not cce_spec_transformers:
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
print(
UNINSTALL_PREFIX

View File

@@ -1,4 +1,5 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
@@ -92,16 +93,16 @@ def parse_requirements():
install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.5.2",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require={
"flash-attn": [
"flash-attn==2.7.0.post2",

View File

@@ -1,8 +0,0 @@
"""Axolotl - Train and fine-tune large language models"""
try:
from importlib.metadata import version
__version__ = version("axolotl")
except ImportError:
__version__ = "unknown"

View File

@@ -380,7 +380,7 @@ def choose_config(path: Path):
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
@@ -391,7 +391,7 @@ def choose_config(path: Path):
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
@@ -432,8 +432,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
@@ -442,10 +440,12 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)
prepare_plugins(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)

View File

@@ -2,7 +2,6 @@
CLI to run inference on a trained model
"""
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -17,7 +16,7 @@ from axolotl.cli import (
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)

View File

@@ -1,233 +0,0 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
import axolotl
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
)
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.merge_sharded_fsdp_weights import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir
from axolotl.cli.merge_lora import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
"""
fetch_from_github(f"{directory}/", dest)
def main():
cli()
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,6 @@
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -12,7 +11,7 @@ from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))

View File

@@ -177,7 +177,7 @@ def merge_fsdp_weights(
state.wait_for_everyone()
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))

View File

@@ -1,218 +0,0 @@
"""Utility methods for axoltl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from pathlib import Path
from types import NoneType
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
LOG = logging.getLogger("axolotl.cli.utils")
def add_options_from_dataclass(config_class: Type[Any]):
"""Create Click options from the fields of a dataclass."""
def decorator(function):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name,
default=field.default,
help=field.metadata.get("description"),
)(function)
else:
option_name = f"--{field.name.replace('_', '-')}"
function = click.option(
option_name,
type=field_type,
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]):
"""Create Click options from the fields of a Pydantic model."""
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
option_name, default=None, help=field.description
)(function)
else:
option_name = f"--{name.replace('_', '-')}"
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
"""Build command list from base command and options."""
cmd = base_cmd.copy()
for key, value in options.items():
if value is None:
continue
key = key.replace("_", "-")
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.extend([f"--{key}", str(value)])
return cmd
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha)
raw_base_url: Base URL for raw GitHub content
dest_path: Local destination directory
dir_prefix: Directory prefix to filter files
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "new"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
def fetch_from_github(
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
dest_dir: Local destination directory
max_workers: Maximum number of concurrent downloads
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)
# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}
if not files:
raise click.ClickException(f"No files found in {dir_prefix}")
# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {
"new": [],
"updated": [],
"unchanged": [],
"error": [],
}
# Process files in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(
download_file,
(file_path, remote_sha),
raw_base_url,
dest_path,
dir_prefix,
): file_path
for file_path, remote_sha in files.items()
}
# Process completed tasks as they finish
for future in concurrent.futures.as_completed(future_to_file):
file_path = future_to_file[future]
try:
file_path, status = future.result()
files_processed[status].append(file_path)
except (requests.RequestException, IOError) as request_error:
print(f"Error processing {file_path}: {str(request_error)}")
files_processed["error"].append(file_path)
# Log summary
LOG.info("\nSync Summary:")
LOG.info(f"New files: {len(files_processed['new'])}")
LOG.info(f"Updated files: {len(files_processed['updated'])}")
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")

View File

@@ -3,88 +3,36 @@ helper functions for fixing the embeddings/tokenizer
"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
# GNU LESSER GENERAL PUBLIC LICENSE
# Version 3, 29 June 2007
#
# Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
# Everyone is permitted to copy and distribute verbatim copies
# of this license document, but changing it is not allowed.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import logging
from collections import Counter
import datasets
import numpy as np
import torch
LOG = logging.getLogger("axolotl.core.tokenizer_utils")
@torch.inference_mode()
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
):
@torch.inference_mode
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
"""
Llama-3 for eg has untrained vectors in the base model.
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
We reset them to the mean of the rest of the tokens
Many of the newer models have reserved tokens that are not trained.
"""
# Code licensed under LGPL
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
chat_template = getattr(tokenizer, "chat_template", None)
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
# Ignore some model checks for now
if not ignored_tokenizer_names:
ignored_tokenizer_names = []
if (
model.config._name_or_path # pylint: disable=protected-access
in ignored_tokenizer_names
):
return
# Sometimes the sizes can be different like in vision models
# Ie <image> is in input, but not in output
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
embedding_matrix = embedding_matrix[:, :min_size]
lm_head_matrix = lm_head_matrix[:, :min_size]
# Get untrained tokens
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
# Check lm_head as well
# Does NOT work for Llama 3.1!!
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
# We instead check for repeated vectors
lm_head_where = torch.where(indicator_untrained1)[0]
lm_head_bad = lm_head_matrix[lm_head_where]
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
counter = Counter()
for row in lm_head_bad:
counter[hash(row.data.tobytes())] += 1
counter = Counter({k: c for k, c in counter.items() if c >= 2})
lm_head_where = lm_head_where.cpu().numpy()
final_bad_lm_head = []
for j, row in enumerate(lm_head_bad):
if hash(row.data.tobytes()) in counter:
final_bad_lm_head.append(lm_head_where[j])
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
indicator_untrained2[final_bad_lm_head] = True
# Combine both checks
indicator_untrained = indicator_untrained1 & indicator_untrained2
# Remove pad token possibility
if hasattr(tokenizer, "pad_token_id"):
pad_token_id = tokenizer.pad_token_id
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
indicator_untrained[pad_token_id] = False
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained
@@ -92,9 +40,10 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
# Get set and actual tokens
where_untrained = where_untrained.tolist()
if len(where_untrained) == 0:
return
return False
# Remove untrained indices where it's longer
where_untrained_set = frozenset(where_untrained)
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
# Remove None items in actual_bad_tokens
@@ -104,14 +53,10 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
if_bad_first = False
if_bad_second = False
# Check tokenizer's chat template for any untrained tokens
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is not None:
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
if isinstance(train_dataset, datasets.IterableDataset):
# Skip the check, since the code below assumes
# an indexable dataset
return
# Check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
@@ -138,69 +83,7 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
# Check if bad tokens exists!
if not if_bad_first and not if_bad_second:
return
# Check if lm_head / embed_token are trainable!
bad_not_trainable = False
if not embedding_matrix.requires_grad:
bad_not_trainable = True
if not lm_head_matrix.requires_grad:
bad_not_trainable = True
if bad_not_trainable: # pylint: disable=too-many-nested-blocks
final_bad_items = []
# Re-check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# Re-check last 250
left = max(size_dataset - 250, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# If no bad tokens, possibly chat template itself has issues?
if len(final_bad_items) == 0:
# Recheck 2000 and last 2000 items
size_dataset = len(train_dataset)
size = min(size_dataset, 2000)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# Re-check last 2000
left = max(size_dataset - 2000, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
for item in input_ids:
if item in where_untrained_set:
final_bad_items.append(item)
# Most likely false signal!
if len(final_bad_items) == 0:
return
raise ValueError(
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
)
return False
# Count all the possible bad tokens
final_counts = np.zeros(
@@ -214,23 +97,6 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
# Get counts for untrained tokens
counts_untrained = final_counts[where_untrained]
# Identify untrained tokens seen in train_dataset
indices_seen_in_train = np.where(counts_untrained > 0)[0]
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
if len(tokens_to_update) == 0:
LOG.info(
"No untrained tokens found in train_dataset. No embeddings were modified."
)
return
# Log the token IDs that are being rescaled
LOG.info(
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
)
# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
@@ -247,26 +113,38 @@ def fix_untrained_tokens( # pylint: disable=too-many-return-statements
mean_embedding = sum_embedding / n_trained
mean_lm_head = sum_lm_head / n_trained
# Compute scaling for tokens to update
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
mean_embedding = (
mean_embedding.repeat(
(
n_untrained,
1,
)
)
* scaling
)
mean_lm_head = (
mean_lm_head.repeat(
(
n_untrained,
1,
)
)
* scaling
)
where_null = scaling.ravel() == 0
mean_embedding[where_null] = 0
mean_lm_head[where_null] = 0
# Prepare mean embeddings for tokens to update
mean_embedding_repeated = (
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)
mean_lm_head_repeated = (
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
)
# Update embeddings only for tokens seen in train_dataset
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
embedding_matrix.dtype
)
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
# Set them to the mean
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return
return True

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -958,15 +957,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -974,13 +971,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
return super().log(logs)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1164,22 +1155,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1188,22 +1163,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1212,49 +1171,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1263,22 +1179,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1287,15 +1187,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
"""
@@ -2006,6 +1897,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
kwargs["chat_template_type"] = self.cfg.chat_template
else:
collator = DataCollatorForSeq2Seq

View File

@@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def check_dtype_is_half(cls, data):
if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
if not (data.get("bf16") or data.get("fp16")):
raise ValueError(
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
"Please set `bf16` or `fp16` to `True`."

View File

@@ -1,80 +0,0 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -1,206 +0,0 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -9,7 +9,10 @@ import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = get_logger("axolotl.monkeypatch.unsloth")
@@ -52,6 +55,11 @@ def original_apply_o(self, hidden_states):
return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -94,22 +102,12 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
@@ -141,7 +139,6 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821

File diff suppressed because one or more lines are too long

View File

@@ -22,6 +22,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
chat_template_type: Optional[str] = None
packing: bool = False
max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True
@@ -35,142 +36,187 @@ class MultiModalChatDataCollator(DataCollatorMixin):
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images
examples,
self.processor,
self.chat_template,
self.max_images,
chat_template_type=self.chat_template_type,
)
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
def preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
dict in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
@staticmethod
def process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
@staticmethod
def pixtral_chat_conversion(messages):
is_single_message = not isinstance(messages, list)
if is_single_message:
messages = [messages]
for i, message in enumerate(messages):
if message["role"] == "user":
for j, content in enumerate(message["content"]):
if "type" in content and content["type"] == "text":
messages[i]["content"][j] = {
"type": "text",
"content": content["text"],
}
if message["role"] == "assistant":
messages[i]["content"] = message["content"][0]["text"]
if is_single_message:
return messages[0]
return messages
@staticmethod
def process_rows(
examples,
processor,
chat_template,
max_images,
length_only=False,
chat_template_type=None,
):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
def _preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
dict in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
def _process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
# Preprocess the examples
examples = _preprocess(examples)
examples = __class__.preprocess(examples)
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
if chat_template_type == "pixtral":
texts = [
processor.apply_chat_template(
__class__.pixtral_chat_conversion(example["messages"]),
chat_template=chat_template,
tokenize=False,
)
for example in examples
]
else:
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = _process_images(examples, max_images=max_images)
images = __class__.process_images(examples, max_images=max_images)
if chat_template_type == "llava":
# LLava1.5 does not support multiple images
images = [image[0] for image in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -179,9 +225,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
if chat_template_type == "qwen2_vl":
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
else:
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels

View File

@@ -132,7 +132,7 @@ def normalize_config(cfg):
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
and model_config.model_type in ["llava", "mllama", "qwen2_vl", "qwen2_5_vl"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
@@ -145,7 +145,12 @@ def normalize_config(cfg):
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
try:
model_config = model_config.text_config
except AttributeError:
# for qwen2_vl
model_config = model_config.get_text_config()
cfg.model_config_type = model_config.model_type
@@ -153,7 +158,7 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type in ["llama", "mllama_text_model"]
and model_config.model_type == ["llama", "mllama_text_model"]
)
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()

View File

@@ -51,6 +51,7 @@ class ChatTemplate(str, Enum):
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
llava = "llava" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
@@ -60,6 +61,8 @@ class ChatTemplate(str, Enum):
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
pixtral = "pixtral" # pylint: disable=invalid-name
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
@@ -1521,6 +1524,19 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_hopper_8bit_lora(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_deepspeed(cls, data):

View File

@@ -2,9 +2,11 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -42,11 +44,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
retry_on_request_exceptions,
)
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -57,6 +55,27 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []

View File

@@ -1,57 +1,13 @@
"""data handling helpers"""
import functools
import hashlib
import logging
import time
from enum import Enum
import huggingface_hub
import requests
from datasets import Dataset
LOG = logging.getLogger("axolotl")
class RetryStrategy(Enum):
"""
Enum for retry strategies.
"""
CONSTANT = 1
LINEAR = 2
EXPONENTIAL = 3
def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
if retry_strategy == RetryStrategy.EXPONENTIAL:
step_delay = delay * 2**attempt
elif retry_strategy == RetryStrategy.LINEAR:
step_delay = delay * (attempt + 1)
else:
step_delay = delay # Use constant delay.
time.sleep(step_delay)
else:
raise exc
return wrapper
return decorator
def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()

View File

@@ -30,6 +30,7 @@ from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
@@ -91,7 +92,11 @@ def get_module_class_from_name(module, name):
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
if cfg.is_multimodal:
model_config = model_config.text_config
try:
model_config = model_config.text_config
except AttributeError:
# for qwen2_vl
model_config = model_config.get_text_config()
quant_config_exists = (
hasattr(model_config, "quantization_config")
@@ -367,7 +372,11 @@ class ModelLoader:
# init model config
self.model_config = load_model_config(cfg)
if cfg.is_multimodal:
self.text_model_config = self.model_config.text_config
try:
self.text_model_config = self.model_config.text_config
except AttributeError:
# for qwen2_vl
self.text_model_config = self.model_config.get_text_config()
else:
self.text_model_config = self.model_config
@@ -380,28 +389,12 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -413,14 +406,10 @@ class ModelLoader:
and self.cfg.flash_attention
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
try:
auto_map_config = self.model_config["auto_map"]
except TypeError:
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
has_remote_code = (
"auto_map" in self.model_config
and "AutoModelForCausalLM" in self.model_config["auto_map"]
)
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
@@ -573,6 +562,10 @@ class ModelLoader:
self.AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
elif self.model_config.model_type == "qwen2_vl":
self.AutoModelLoader = ( # pylint: disable=invalid-name
AutoModelForImageTextToText
)
else:
self.AutoModelLoader = (
AutoModelForVision2Seq # pylint: disable=invalid-name
@@ -1065,7 +1058,9 @@ class ModelLoader:
and self.model.get_input_embeddings().num_embeddings < embeddings_len
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:
if self.cfg.mean_resizing_embeddings is not None and not (
self.model_config.model_type == "llava"
):
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
else:

View File

View File

@@ -1,36 +0,0 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner
VALID_TEST_CONFIG = """
base_model: HuggingFaceTB/SmolLM2-135M
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
sequence_len: 2048
max_steps: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1e-3
special_tokens:
pad_token: <|endoftext|>
"""
@pytest.fixture
def cli_runner():
return CliRunner()
@pytest.fixture
def valid_test_config():
return VALID_TEST_CONFIG
@pytest.fixture
def config_path(tmp_path):
"""Creates a temporary config file"""
path = tmp_path / "config.yml"
path.write_text(VALID_TEST_CONFIG)
return path

View File

@@ -1,38 +0,0 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch
def test_fetch_cli_examples(cli_runner):
"""Test fetch command with examples directory"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
result = cli_runner.invoke(fetch, ["examples"])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("examples/", None)
def test_fetch_cli_deepspeed(cli_runner):
"""Test fetch command with deepspeed_configs directory"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
result = cli_runner.invoke(fetch, ["deepspeed_configs"])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("deepspeed_configs/", None)
def test_fetch_cli_with_dest(cli_runner, tmp_path):
"""Test fetch command with custom destination"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
custom_dir = tmp_path / "tmp_examples"
result = cli_runner.invoke(fetch, ["examples", "--dest", str(custom_dir)])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("examples/", str(custom_dir))
def test_fetch_cli_invalid_directory(cli_runner):
"""Test fetch command with invalid directory choice"""
result = cli_runner.invoke(fetch, ["invalid"])
assert result.exit_code != 0

View File

@@ -1,30 +0,0 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli
def test_inference_basic(cli_runner, config_path):
"""Test basic inference"""
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_gradio(cli_runner, config_path):
"""Test basic inference (gradio path)"""
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0

View File

@@ -1,47 +0,0 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli
def test_build_command():
"""Test converting dict of options to CLI arguments"""
base_cmd = ["accelerate", "launch"]
options = {
"learning_rate": 1e-4,
"batch_size": 8,
"debug": True,
"use_fp16": False,
"null_value": None,
}
result = build_command(base_cmd, options)
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
]
def test_invalid_command_options(cli_runner):
"""Test handling of invalid command options"""
result = cli_runner.invoke(
cli,
[
"train",
"config.yml",
"--invalid-option",
"value",
],
)
assert result.exit_code != 0
assert "No such option" in result.output
def test_required_config_argument(cli_runner):
"""Test commands fail properly when config argument is missing"""
result = cli_runner.invoke(cli, ["train"])
assert result.exit_code != 0
assert "Missing argument 'CONFIG'" in result.output

View File

@@ -1,56 +0,0 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli
def test_merge_lora_basic(cli_runner, config_path):
"""Test basic merge_lora command"""
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):
"""Test merge_lora with custom lora and output directories"""
lora_dir = tmp_path / "lora"
output_dir = tmp_path / "output"
lora_dir.mkdir()
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli,
[
"merge-lora",
str(config_path),
"--lora-model-dir",
str(lora_dir),
"--output-dir",
str(output_dir),
],
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["lora_model_dir"] == str(lora_dir)
assert mock_do_cli.call_args.kwargs["output_dir"] == str(output_dir)
def test_merge_lora_nonexistent_config(cli_runner, tmp_path):
"""Test merge_lora with nonexistent config"""
config_path = tmp_path / "nonexistent.yml"
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
assert result.exit_code != 0
def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):
"""Test merge_lora with nonexistent lora directory"""
lora_dir = tmp_path / "nonexistent"
result = cli_runner.invoke(
cli, ["merge-lora", str(config_path), "--lora-model-dir", str(lora_dir)]
)
assert result.exit_code != 0

View File

@@ -1,60 +0,0 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
"""Test merge_sharded_fsdp_weights command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command with save_path option"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--save-path",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,71 +0,0 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch
import pytest
from axolotl.cli.main import cli
@pytest.fixture(autouse=True)
def cleanup_last_run_prepared():
yield
if Path("last_run_prepared").exists():
shutil.rmtree("last_run_prepared")
def test_preprocess_config_not_found(cli_runner):
"""Test preprocess fails when config not found"""
result = cli_runner.invoke(cli, ["preprocess", "nonexistent.yml"])
assert result.exit_code != 0
def test_preprocess_basic(cli_runner, config_path):
"""Test basic preprocessing with minimal config"""
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["download"] is True
def test_preprocess_without_download(cli_runner, config_path):
"""Test preprocessing without model download"""
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli, ["preprocess", str(config_path), "--no-download"]
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["download"] is False
def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
"""Test preprocessing with custom dataset path"""
config_path = tmp_path / "config.yml"
custom_path = tmp_path / "custom_prepared"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli,
[
"preprocess",
str(config_path),
"--dataset-prepared-path",
str(custom_path.absolute()),
],
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
custom_path.absolute()
)

View File

@@ -1,76 +0,0 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_shard_with_accelerate(cli_runner, config_path):
"""Test shard command with accelerate"""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_shard_no_accelerate(cli_runner, config_path):
"""Test shard command without accelerate"""
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
assert mock.called
assert result.exit_code == 0
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
"""Test shard command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
catch_exceptions=False,
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_shard_with_save_dir(cli_runner, config_path):
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--save-dir",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,98 +0,0 @@
"""pytest tests for axolotl CLI train command."""
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
def test_train_cli_validation(cli_runner):
"""Test CLI validation"""
# Test missing config file
result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["train", str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.train",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
"""Test CLI arguments properly override config values"""
config_path = tmp_path / "config.yml"
output_dir = tmp_path / "model-out"
test_config = valid_test_config.replace(
"output_dir: model-out", f"output_dir: {output_dir}"
)
config_path.write_text(test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2

View File

@@ -1,72 +0,0 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch
import click
import pytest
import requests
from axolotl.cli.utils import fetch_from_github
# Sample GitHub API response
MOCK_TREE_RESPONSE = {
"tree": [
{"path": "examples/config1.yml", "type": "blob", "sha": "abc123"},
{"path": "examples/config2.yml", "type": "blob", "sha": "def456"},
{"path": "other/file.txt", "type": "blob", "sha": "xyz789"},
]
}
@pytest.fixture
def mock_responses():
"""Mock responses for API and file downloads"""
def mock_get(url, timeout=None): # pylint: disable=unused-argument
response = Mock()
if "api.github.com" in url:
response.text = json.dumps(MOCK_TREE_RESPONSE)
else:
response.content = b"file content"
return response
return mock_get
def test_fetch_from_github_new_files(tmp_path, mock_responses):
"""Test fetching new files"""
with patch("requests.get", mock_responses):
fetch_from_github("examples/", tmp_path)
# Verify files were created
assert (tmp_path / "config1.yml").exists()
assert (tmp_path / "config2.yml").exists()
assert not (tmp_path / "file.txt").exists()
def test_fetch_from_github_unchanged_files(tmp_path, mock_responses):
"""Test handling of unchanged files"""
# Create existing file with matching SHA
existing_file = tmp_path / "config1.yml"
existing_file.write_bytes(b"file content")
with patch("requests.get", mock_responses):
fetch_from_github("examples/", tmp_path)
# File should not be downloaded again
assert existing_file.read_bytes() == b"file content"
def test_fetch_from_github_invalid_prefix(mock_responses):
"""Test error handling for invalid directory prefix"""
with patch("requests.get", mock_responses):
with pytest.raises(click.ClickException):
fetch_from_github("nonexistent/", None)
def test_fetch_from_github_network_error():
"""Test handling of network errors"""
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)

View File

@@ -1,109 +1,68 @@
"""
shared pytest fixtures
"""
import functools
import importlib
import shutil
import sys
import tempfile
import time
import pytest
import requests
from huggingface_hub import snapshot_download
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
snapshot_download("HuggingFaceTB/SmolLM2-135M")
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
# download the model
snapshot_download_w_retry("JackFram/llama-68m")
snapshot_download("JackFram/llama-68m")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
snapshot_download("Qwen/Qwen2.5-0.5B")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset
snapshot_download_w_retry(
snapshot_download(
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
)
@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
# download the dataset
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
# download the dataset
snapshot_download_w_retry(
snapshot_download(
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
snapshot_download(
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
)
@@ -115,40 +74,3 @@ def temp_dir():
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers",),
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)

View File

@@ -7,11 +7,12 @@ from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -90,8 +91,12 @@ class TestMultiGPUEval:
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.5, "Loss is too high"
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
@@ -159,5 +164,9 @@ class TestMultiGPUEval:
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.9, "Eval Loss is too high")
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.9, "Loss is too high"

View File

@@ -14,6 +14,8 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import is_hopper
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -142,6 +144,7 @@ class TestMultiGPULlama:
]
)
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(

View File

@@ -42,7 +42,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.02,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -86,7 +86,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",

View File

@@ -1,47 +0,0 @@
"""
test cases to make sure the plugin args are loaded from the config file
"""
from pathlib import Path
import yaml
from axolotl.cli import load_cfg
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
class TestPluginArgs:
"""
test class for plugin args loaded from the config file
"""
def test_liger_plugin_args(self, temp_dir):
test_cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
"liger_layer_norm": True,
"liger_rope": True,
"liger_rms_norm": False,
"liger_glu_activation": True,
"liger_fused_linear_cross_entropy": True,
}
)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(test_cfg.to_dict()))
cfg = load_cfg(str(Path(temp_dir) / "config.yaml"))
assert cfg.liger_layer_norm is True
assert cfg.liger_rope is True
assert cfg.liger_rms_norm is False
assert cfg.liger_glu_activation is True
assert cfg.liger_fused_linear_cross_entropy is True

View File

@@ -8,6 +8,7 @@ from importlib import reload
from pathlib import Path
import pytest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
@@ -16,7 +17,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -93,6 +94,9 @@ class TestFAXentropyLlama:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 1.5, "Loss is too high"

View File

@@ -40,7 +40,7 @@ class TestFalconPatched(unittest.TestCase):
"lora_dropout": 0.1,
"lora_target_linear": True,
"lora_modules_to_save": ["word_embeddings", "lm_head"],
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
@@ -80,7 +80,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",

View File

@@ -38,7 +38,7 @@ class TestFusedLlama(unittest.TestCase):
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.02,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -98,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -39,7 +39,7 @@ class TestMistral(unittest.TestCase):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
@@ -80,7 +80,7 @@ class TestMistral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",

View File

@@ -40,7 +40,7 @@ class TestMixtral(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
@@ -78,7 +78,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{

View File

@@ -38,7 +38,7 @@ class TestPhiMultipack(unittest.TestCase):
"pad_to_sequence_len": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -6,6 +6,8 @@ import os
from pathlib import Path
import pytest
from e2e.utils import most_recent_subdir
from tbparse import SummaryReader
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -13,8 +15,6 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -36,9 +36,6 @@ class TestUnslothQLoRA:
"sequence_len": 1024,
"sample_packing": sample_packing,
"flash_attention": True,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
@@ -76,18 +73,18 @@ class TestUnslothQLoRA:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
@@ -126,9 +123,12 @@ class TestUnslothQLoRA:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
@pytest.mark.parametrize(
"sdp_attention",
@@ -139,9 +139,6 @@ class TestUnslothQLoRA:
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"unsloth_lora_mlp": True,
"unsloth_lora_qkv": True,
"unsloth_lora_o": True,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
@@ -181,6 +178,9 @@ class TestUnslothQLoRA:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -7,13 +7,15 @@ import os
import unittest
from pathlib import Path
from tbparse import SummaryReader
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -64,9 +66,12 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
@with_temp_dir
def test_train_w_embedding_lr(self, temp_dir):
@@ -108,6 +113,9 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -6,6 +6,7 @@ import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
@@ -14,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,6 +66,9 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -7,13 +7,15 @@ import os
import unittest
from pathlib import Path
from tbparse import SummaryReader
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -83,6 +85,9 @@ class TestReLoraLlama(unittest.TestCase):
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name
assert df.value.values[-1] < 0.2, "grad_norm is too high"

View File

@@ -12,7 +12,6 @@ import torch
# from importlib.metadata import version
from packaging import version
from tbparse import SummaryReader
def with_temp_dir(test_func):
@@ -67,17 +66,3 @@ def require_torch_2_5_1(test_case):
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None:
"""
helper function to parse and check tensorboard logs
"""
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err

View File

@@ -1,25 +0,0 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable,
check_training_step_is_patchable,
)
class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests."""
def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_training_step_is_patchable(),
"HF transformers Trainer.training_step has changed and isn't patchable",
)
def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_forward_is_patchable(),
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
)

View File

@@ -4,7 +4,6 @@ shared fixtures for prompt strategies tests
import pytest
from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
@@ -61,17 +60,6 @@ def fixture_basic_dataset():
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
filename="special_tokens_map.json",
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
filename="tokenizer_config.json",
)
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer

View File

@@ -1,5 +1,5 @@
"""
test module for the axolotl.utils.data module
test module for the axolotl.utis.data module
"""
import unittest