Compare commits


25 Commits

Author SHA1 Message Date
bursteratom  60c98a4353  stuff  2024-12-13 15:44:51 -05:00
bursteratom  c760d2b815  test accelerator  2024-12-12 12:29:35 -05:00
bursteratom  2014f58181  set os environ RANK  2024-12-11 11:45:07 -05:00
bursteratom  b5f9dd44f2  set os environ RANK  2024-12-11 11:40:20 -05:00
bursteratom  b17b1aada7  initialise process group for tp  2024-12-11 11:37:21 -05:00
bursteratom  85381b6b15  initialise process group for tp  2024-12-11 11:35:16 -05:00
bursteratom  acde081321  test lora tp  2024-12-11 11:19:34 -05:00
bursteratom  e4c68a0cbc  test lora tp  2024-12-11 11:11:52 -05:00
bursteratom  3855f5c3d3  tp example tp auto  2024-12-11 11:03:39 -05:00
bursteratom  5dd566dc63  tp example  2024-12-11 11:01:23 -05:00
bursteratom  42389c1f78  enable tensor parallel  2024-12-11 10:38:14 -05:00
Wing Lian  d009ead101  fix build w pyproject to respect insalled torch version (#2168)  2024-12-10 16:25:25 -05:00
    * fix build w pyproject to respect insalled torch version
    * include in manifest
    * disable duplicate code check for now
    * move parser so it can be found
    * add checks for correct pytorch version so this doesn't slip by again
Wing Lian  6aa31b44c6  make sure to checkout tag before creating release (#2164)  2024-12-09 14:20:16 -05:00
Wing Lian  9001859b0b  fix release command (#2163) [skip ci]  2024-12-09 14:12:45 -05:00
Wing Lian  34d3c8dcfb  [docs] Update README Quickstart to use CLI (#2137)  2024-12-09 14:03:19 -05:00
    * update quickstart for new CLI
    * add blurb about bleeding edge builds
    * missed a yaml reference
    * prefer lora over qlora for examples
    * fix commands for parity with previous instructions
    * consistency on pip/pip3 install
    * one more parity pip=>pip3
    * remove extraneous options in example yaml
    * update copy
    * update badges and for discord and socials in readme
    * Fix a few broken links
    * bump version to 0.6.0 for release
    Co-authored-by: NanoCode012 <nano@axolotl.ai>
Wing Lian  ab4b32187d  need to update deepspeed version in extras too (#2161) [skip ci]  2024-12-09 14:01:44 -05:00
    * need to update deepspeed version in extras too
    * fix patch import
    * fix monkeypatch reloading in tests and deepspeed patch
    * remove duplicated functionality fixture
    * reset LlamaForCausalLM too in fixtures for cce patch
    * reset llama attn too
    * disable xformers patch for cce
    * skip problematic test on low usage functionality
NanoCode012  5d6b088997  fix: chat_template masking due to truncation, consolidate turn build and keys within field (#2123) [skip ci]  2024-12-09 13:49:38 -05:00
    * fix: chat_template masking due to truncation, consolidate turn build and keys within field
    * fix: revert roles change
    * fix: handling of training and training_detail
    * fix: do not skip setting eos mask even if failed finding turn boundary
    * fix: truncate reward modelling outputs
Wing Lian  3862267040  don't add dataset tags if empty due to all local data paths (#2162) [skip ci]  2024-12-09 13:49:18 -05:00
NanoCode012  c78de6f214  feat: add kto example (#2158) [skip ci]  2024-12-09 08:17:27 -05:00
Wing Lian  b1e8286c57  add missing __init__ to optimizers path (#2160) [skip ci]  2024-12-09 08:17:08 -05:00
Wing Lian  40907c6887  upgrade deepspeed to 0.16.1 (#2157)  2024-12-09 07:25:10 -05:00
NanoCode012  6a342feda2  fix: duplicate mlflow logging (#2109) [skip ci]  2024-12-09 07:24:48 -05:00
Wing Lian  0c25bc07a2  use manual version for now (#2156)  2024-12-08 21:09:12 -05:00
Sunny Liu  343a4d8855  Fixing issue#2134 Axolotl Crashes At The End Of Training If Base Model Is Local (#2140)  2024-12-08 16:39:05 -05:00
Wing Lian  393853751e  add additional fft deepspeed variants (#2153) [skip ci]  2024-12-08 16:38:47 -05:00
38 changed files with 1025 additions and 213 deletions

@@ -13,10 +13,13 @@ jobs:
     permissions:
       contents: write
     steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
       - name: Create release
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
+        run: gh release create "$GITHUB_REF_NAME" --generate-notes
   pypi-publish:
     name: Upload release to PyPI
     runs-on: ubuntu-latest
@@ -38,7 +41,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 install wheel packaging
-          pip3 install -e .
+          pip3 install --no-build-isolation -e .
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
       - name: Extract tag name

@@ -60,11 +60,15 @@ jobs:
         run: |
           pip3 install --upgrade pip
           pip3 install --upgrade packaging
-          pip3 install -U -e .
+          pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
       - name: Ensure axolotl CLI was installed
         run: |
           axolotl --help
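The "Make sure PyTorch version wasn't clobbered" step added above reduces to a substring check on `torch.__version__` (local build strings look like `2.4.1+cu124`). A minimal sketch of the same check, factored into a plain function so it can be tested without a GPU environment (the function name is illustrative, not from the repo):

```python
def version_clobbered(expected: str, actual: str) -> bool:
    """Return True when the installed version string no longer contains
    the pinned version, e.g. after a dependency upgraded torch."""
    return expected not in actual


# The CI step is equivalent to:
#   python -c "import torch; assert '2.4.1' in torch.__version__"
print(version_clobbered("2.4.1", "2.5.1+cu124"))  # → True
print(version_clobbered("2.4.1", "2.4.1+cu124"))  # → False
```

Running this guard immediately after `pip install` catches the exact failure mode the commits fix: a build backend silently pulling in a different torch wheel.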

@@ -78,19 +78,23 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 show torch
-          pip3 install -U -e .
+          pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
       - name: Ensure axolotl CLI was installed
         run: |
           axolotl --help
       - name: Run tests
         run: |
-          pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
-          pytest tests/patched/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+          pytest -v tests/patched/
       - name: cleanup pip cache
         run: |
@@ -120,7 +124,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
+          pip3 install --upgrade packaging setuptools setuptools_scm build wheel
       - name: Install PyTorch
         run: |
@@ -129,20 +133,24 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 show torch
-          python3 setup.py sdist
-          pip3 install dist/axolotl*.tar.gz
+          python -m build --no-isolation --sdist
+          pip3 install --no-build-isolation dist/axolotl*.tar.gz
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
       - name: Ensure axolotl CLI was installed
         run: |
           axolotl --help
       - name: Run tests
         run: |
-          pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
-          pytest tests/patched/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+          pytest -v tests/patched/
       - name: cleanup pip cache
         run: |

@@ -1,4 +1,5 @@
 include requirements.txt
 include README.md
 include LICENSE
+include src/setuptools_axolotl_dynamic_dependencies.py
 recursive-include axolotl *.py

README.md

@@ -10,9 +10,13 @@
 <img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
 <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
 <a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
+<br/>
+<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
 <img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
-</p>
-<p align="center">
+<br/>
+<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
+<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
+<br/>
 <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
 <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
 </p>
@@ -42,7 +46,8 @@ Features:
 - [Axolotl](#axolotl)
 - [Table of Contents](#table-of-contents)
 - [Quickstart ⚡](#quickstart-)
-- [Usage](#usage)
+- [Edge Builds](#edge-builds-)
+- [Axolotl CLI Usage](#axolotl-cli-usage)
 - [Badge ❤🏷️](#badge-)
 - [Contributing 🤝](#contributing-)
 - [Sponsors 🤝❤](#sponsors-)
@@ -107,58 +112,49 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
 **Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
 ```bash
-git clone https://github.com/axolotl-ai-cloud/axolotl
+pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
+
+# download examples and optionally deepspeed configs to the local path
+axolotl fetch examples
+axolotl fetch deepspeed_configs  # OPTIONAL
+
+# finetune using lora
+axolotl train examples/llama-3/lora-1b.yml
+```
+
+### Edge Builds 🏎️
+
+If you're looking for the latest features and updates between releases, you'll need to install
+from source.
+
+```bash
+git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
 pip3 install packaging ninja
-pip3 install -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```
-### Usage
-```bash
-# preprocess datasets - optional but recommended
-CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
-# finetune lora
-accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
-# inference
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./outputs/lora-out"
-# gradio
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./outputs/lora-out" --gradio
-# remote yaml files - the yaml config can be hosted on a public URL
-# Note: the yaml config must directly link to the **raw** yaml
-accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
-```
-### Axolotl CLI
-If you've installed this package using `pip` from source, we now support a new, more
-streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
-the above commands:
+### Axolotl CLI Usage
+We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
 ```bash
 # preprocess datasets - optional but recommended
-CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
+CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
 # finetune lora
-axolotl train examples/openllama-3b/lora.yml
+axolotl train examples/llama-3/lora-1b.yml
 # inference
-axolotl inference examples/openllama-3b/lora.yml \
+axolotl inference examples/llama-3/lora-1b.yml \
     --lora-model-dir="./outputs/lora-out"
 # gradio
-axolotl inference examples/openllama-3b/lora.yml \
+axolotl inference examples/llama-3/lora-1b.yml \
     --lora-model-dir="./outputs/lora-out" --gradio
 # remote yaml files - the yaml config can be hosted on a public URL
 # Note: the yaml config must directly link to the **raw** yaml
-axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
+axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
 ```
 We've also added a new command for fetching `examples` and `deepspeed_configs` to your
@@ -175,6 +171,36 @@ axolotl fetch deepspeed_configs
 axolotl fetch examples --dest path/to/folder
 ```
+### Legacy Usage
+<details>
+<summary>Click to Expand</summary>
+
+While the Axolotl CLI is the preferred method for interacting with axolotl, we
+still support the legacy `-m axolotl.cli.*` usage.
+
+```bash
+# preprocess datasets - optional but recommended
+CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
+
+# finetune lora
+accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
+
+# inference
+accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
+    --lora_model_dir="./outputs/lora-out"
+
+# gradio
+accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
+    --lora_model_dir="./outputs/lora-out" --gradio
+
+# remote yaml files - the yaml config can be hosted on a public URL
+# Note: the yaml config must directly link to the **raw** yaml
+accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
+```
+
+</details>
 ## Badge ❤🏷️
 Building something cool with Axolotl? Consider adding a badge to your model card.
@@ -294,7 +320,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
 3. Install Axolotl along with python dependencies
    ```bash
    pip3 install packaging
-   pip3 install -e '.[flash-attn,deepspeed]'
+   pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
    ```
 4. (Optional) Login to Huggingface to use gated models/datasets.
    ```bash
@@ -373,7 +399,7 @@ Please use WSL or Docker!
 Use the below instead of the install method in QuickStart.
 ```
-pip3 install -e '.'
+pip3 install --no-build-isolation -e '.'
 ```
 More info: [mac.md](/docs/mac.qmd)

@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     fi
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi
 RUN python scripts/unsloth_install.py | sh

@@ -1,7 +1,10 @@
 #!/bin/bash
 set -e
+python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
-pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
-pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
+# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
+pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
+pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
 pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi
 RUN python scripts/unsloth_install.py | sh

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
     fi
 # So we can test the Docker image

@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
 cd flash-attention
 export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
 patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
-pip install .
+pip install --no-build-isolation .
 ```
 ### 6. Install Axolotl
@@ -63,7 +63,7 @@ Clone and install Axolotl:
 git clone https://github.com/axolotl-ai-cloud/axolotl
 cd axolotl
 pip install packaging ninja
-pip install -e .
+pip install --no-build-isolation -e .
 ```
 ### 7. Apply xformers Workaround

@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
 ```bash
 pip3 install packaging
-pip3 install -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```
 #### Remote Hosts
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
 ```bash
 pip3 install packaging
-pip3 install -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```
 ### Attach To Container

@@ -52,6 +52,26 @@ datasets:
     type: chat_template.argilla
 ```
+#### KTO
+
+```yaml
+rl: kto
+rl_beta: 0.5
+kto_desirable_weight: 0.2
+
+remove_unused_columns: false
+
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    type: llama3.ultra
+    split: train
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+```
+
 #### Using local dataset files
 ```yaml
 datasets:

@@ -24,7 +24,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install axolotl[deepspeed]"
+    "!pip install --no-build-isolation axolotl[deepspeed]"
   ]
  },
 {


@@ -0,0 +1,58 @@
base_model: NousResearch/Meta-Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>


@@ -0,0 +1,74 @@
base_model: NousResearch/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"


@@ -0,0 +1,73 @@
base_model: NousResearch/Meta-Llama-3.1-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>


@@ -0,0 +1,75 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/qlora-out
remove_unused_columns: false
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: false # not supported with kto
eval_sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

@@ -1,4 +1,4 @@
-base_model: meta-llama/Llama-3.2-1B
+base_model: NousResearch/Llama-3.2-1B
 load_in_8bit: false
 load_in_4bit: true
@@ -22,7 +22,6 @@ pad_to_sequence_len: true
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_linear: true
 lora_fan_in_fan_out:
 lora_target_modules:
   - gate_proj

@@ -17,3 +17,10 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
 Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"

 [tool.setuptools_scm]
+
+[tool.setuptools]
+py-modules = ["setuptools_axolotl_dynamic_dependencies"]
+include-package-data = true
+
+[tool.setuptools.cmdclass]
+build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
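The `[tool.setuptools.cmdclass]` table above routes the build through a custom `build_py` command in `setuptools_axolotl_dynamic_dependencies`. The repo's actual module isn't shown in this diff, so the following is only a hypothetical sketch of the shape such a hook takes: subclass setuptools' `build_py`, adjust dependency metadata, then defer to the normal build.

```python
# Hypothetical sketch; the real setuptools_axolotl_dynamic_dependencies
# module may compute dependencies differently.
from setuptools.command.build_py import build_py


class BuildPyCommand(build_py):
    """Hook point for computing environment-dependent dependencies
    (e.g. based on the already-installed torch) at build time."""

    def run(self):
        # adjust self.distribution.install_requires here if needed,
        # then run the standard build_py step
        super().run()
```

With `--no-build-isolation`, this hook runs against the user's live environment, which is what lets the build respect the installed torch version.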

@@ -1,22 +1,30 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+# START section of dependencies that don't install on Darwin/MacOS
+bitsandbytes==0.45.0
+triton>=2.3.0
+mamba-ssm==1.2.0.post1
+flash-attn==2.7.0.post2
+xformers>=0.0.23.post1
+autoawq==0.2.7.post3
+liger-kernel==0.4.2
+# END section
 packaging==23.2
 peft==0.14.0
 transformers>=4.46.3
 tokenizers>=0.20.1
-bitsandbytes==0.45.0
 accelerate==1.2.0
 datasets==3.1.0
-deepspeed==0.15.4
+deepspeed==0.16.1
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-flash-attn==2.7.0.post2
 sentencepiece
 wandb
 einops
-xformers>=0.0.23.post1
 optimum==1.16.2
 hf_transfer
 colorama
@@ -31,11 +39,6 @@ art
 gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
-autoawq==0.2.7.post3
-triton>=2.3.0
-liger-kernel==0.4.2
-mamba-ssm==1.2.0.post1
 # remote filesystems
 s3fs>=2024.5.0
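The diff groups Linux-only wheels between `# START section` / `# END section` markers, which suggests the dynamic-dependency hook skips that block on macOS. As a guess at how those markers could be consumed (the function name and exact logic are assumptions, not from the repo):

```python
import platform


def filter_requirements(lines, system=None):
    """Drop the marked section on Darwin/MacOS; keep everything else.
    A sketch of one way to consume the START/END markers above."""
    system = system or platform.system()
    skipping = False
    kept = []
    for line in lines:
        stripped = line.strip()
        if stripped.startswith("# START section"):
            # only skip the block when building on macOS
            skipping = system == "Darwin"
            continue
        if stripped.startswith("# END section"):
            skipping = False
            continue
        if not skipping and stripped and not stripped.startswith("#"):
            kept.append(stripped)
    return kept


reqs = [
    "# START section of dependencies that don't install on Darwin/MacOS",
    "bitsandbytes==0.45.0",
    "# END section",
    "packaging==23.2",
]
print(filter_requirements(reqs, system="Darwin"))  # → ['packaging==23.2']
```

On Linux the marked block passes through untouched, so a single requirements file can serve both platforms.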

@@ -13,5 +13,5 @@ cd /workspace
 rm -rf /workspace/axolotl
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-pip install --no-deps -e .
+pip install --no-build-isolation --no-deps -e .
 ```

@@ -1,7 +1,10 @@
 """setup.py for axolotl"""

+import ast
+import os
 import platform
 import re
 from importlib.metadata import PackageNotFoundError, version
+from pathlib import Path

 from setuptools import find_packages, setup
@@ -90,9 +93,24 @@ def parse_requirements():
     return _install_requires, _dependency_links

+def get_package_version():
+    with open(
+        Path(os.path.dirname(os.path.abspath(__file__)))
+        / "src"
+        / "axolotl"
+        / "__init__.py",
+        "r",
+        encoding="utf-8",
+    ) as fin:
+        version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
+        version_ = ast.literal_eval(version_match.group(1))
+    return version_
+
 install_requires, dependency_links = parse_requirements()

 setup(
+    version=get_package_version(),
     package_dir={"": "src"},
     packages=find_packages("src"),
     install_requires=install_requires,
@@ -107,7 +125,7 @@ setup(
         "flash-attn==2.7.0.post2",
     ],
     "deepspeed": [
-        "deepspeed==0.15.4",
+        "deepspeed==0.16.1",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [
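The `get_package_version()` helper added above reads `src/axolotl/__init__.py` and pulls the version out with a regex plus `ast.literal_eval`, so the quoted string is evaluated safely without importing the package. The same pattern, reduced to a self-contained function operating on a source string:

```python
import ast
import re


def extract_version(source: str) -> str:
    """Pull __version__ out of module source: find the assignment with a
    MULTILINE regex, then literal_eval the right-hand side so quoting
    style doesn't matter and no code is executed."""
    match = re.search(r"^__version__\s*=\s*(.*)$", source, re.MULTILINE)
    if match is None:
        raise ValueError("no __version__ assignment found")
    return ast.literal_eval(match.group(1))


init_src = '"""Axolotl - Train and fine-tune large language models"""\n__version__ = "0.6.0"\n'
print(extract_version(init_src))  # → 0.6.0
```

This pairs with the `use manual version for now` commit: the single source of truth is the literal in `__init__.py`, and `setup.py` derives its `version=` from it.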

@@ -1,8 +1,3 @@
 """Axolotl - Train and fine-tune large language models"""
-try:
-    from importlib.metadata import version
-
-    __version__ = version("axolotl")
-except ImportError:
-    __version__ = "unknown"
+__version__ = "0.6.0"

View File

@@ -1319,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
         if hasattr(model, "add_model_tags"):
             model.add_model_tags(["axolotl"])
 
+        if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
+            os.environ["ACCELERATE_USE_TP"] = "true"
+            # self.model =
+
     @property
     def model_ref(self):
         return self._model_ref
@@ -1368,8 +1372,6 @@ class TrainerBuilderBase(abc.ABC):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
         if self.cfg.use_mlflow and is_mlflow_available():
-            from transformers.integrations.integration_utils import MLflowCallback
-
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
             )
@@ -1377,7 +1379,6 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.extend(
                 [
                     SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
-                    MLflowCallback,
                 ]
             )
         if self.cfg.use_comet and is_comet_available():

View File

@@ -204,3 +204,87 @@ def patch_forward_for_ga():
     LlamaForCausalLM.forward = (  # pylint: disable=protected-access
         _fixed_forward  # pylint: disable=undefined-variable  # noqa: F821
     )
+
+
+ORIGINAL_TRAINER_CODE = """
+                    context = (
+                        functools.partial(self.accelerator.no_sync, model=model)
+                        if i != len(batch_samples) - 1
+                        else contextlib.nullcontext
+                    )
+                    with context():
+                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+"""
+
+PATCHED_TRAINER_CODE = """
+                    disable_deepspeed_no_sync = (
+                        self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                        # and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
+                    )
+                    context = (
+                        functools.partial(self.accelerator.no_sync, model=model)
+                        if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
+                        else contextlib.nullcontext
+                    )
+                    with context():
+                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+"""
+
+
+def get_training_loop_code() -> str:
+    training_loop = inspect.getsource(
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    return training_loop
+
+
+def check_training_loop_is_patchable() -> bool:
+    training_loop = get_training_loop_code()
+    training_loop, _ = detab_code(training_loop)
+    return ORIGINAL_TRAINER_CODE in training_loop
+
+
+def patch_training_loop_for_deepspeed_0_16_x():
+    """
+    monkeypatch for fixing the training loop for deepspeed GA
+    see https://github.com/huggingface/transformers/pull/35157
+    """
+    try:
+        training_loop = get_training_loop_code()
+    except OSError:
+        return
+
+    Trainer._original_inner_training_loop = (  # pylint: disable=protected-access
+        training_loop
+    )
+
+    training_loop, _ = detab_code(training_loop)
+    if ORIGINAL_TRAINER_CODE not in training_loop:
+        return
+
+    training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
+    training_loop = training_loop.replace(
+        "def _inner_training_loop(",
+        "def _fixed_inner_training_loop(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in training_loop:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(training_loop, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching _inner_training_loop for fsdp optimizer save")
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        _fixed_inner_training_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )

View File

@@ -28,6 +28,8 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:return: :return:
""" """
max_length = self.prompter.max_length
self.messages = "chosen_messages" self.messages = "chosen_messages"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
@@ -39,6 +41,16 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
         prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
         chosen_tokenized = super().tokenize_prompt(prompt)
 
+        if len(chosen_tokenized["input_ids"]) > max_length:
+            LOG.warning(
+                f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
+            )
+
+        chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
+        chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
+            :max_length
+        ]
+
         self.messages = "rejected_messages"
         # pylint: disable=duplicate-code
         prompt[self.messages] = []
@@ -52,6 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
         )
         rejected_tokenized = super().tokenize_prompt(prompt)
 
+        if len(rejected_tokenized["input_ids"]) > max_length:
+            LOG.warning(
+                f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
+            )
+
+        rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
+            :max_length
+        ]
+        rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
+            :max_length
+        ]
+
         return {
             "input_ids_chosen": chosen_tokenized["input_ids"],
             "attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         "roles": ds_cfg.get("roles"),
         "drop_system_message": ds_cfg.get("drop_system_message", False),
         # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
-        "max_length": cfg.sequence_len + 1
-        if not cfg.reward_model
-        else cfg.sequence_len,
+        "max_length": (
+            cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
+        ),
     }
 
     strategy_params = {
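Both the chosen and rejected branches above clip `input_ids` and `attention_mask` to `max_length`, logging a warning rather than failing. A minimal sketch of that guard (the function name and dict layout are illustrative, not the strategy's real API):

```python
import logging

LOG = logging.getLogger(__name__)


def clip_to_max_length(tokenized: dict, max_length: int) -> dict:
    """Truncate input_ids/attention_mask when they exceed max_length."""
    if len(tokenized["input_ids"]) > max_length:
        LOG.warning(
            "sequence exceeds max sequence length: %d", len(tokenized["input_ids"])
        )
    # slicing is a no-op for shorter sequences, so it is safe unconditionally
    tokenized["input_ids"] = tokenized["input_ids"][:max_length]
    tokenized["attention_mask"] = tokenized["attention_mask"][:max_length]
    return tokenized


sample = {"input_ids": list(range(10)), "attention_mask": [1] * 10}
clipped = clip_to_max_length(sample, max_length=8)
print(len(clipped["input_ids"]))  # prints: 8
```

Note that `max_length` is configured as `sequence_len + 1` in the loader above precisely so over-length sequences are detectable before clipping.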
strategy_params = { strategy_params = {

View File

@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
             "gpt": "assistant",
             "system": "system",
         }
+
         self.message_field_role = message_field_role
         self.message_field_content = message_field_content
         self.message_field_training = message_field_training
@@ -53,21 +54,9 @@ class ChatTemplatePrompter(Prompter):
         self.drop_system_message = drop_system_message
 
     def build_prompt(self, conversation, add_generation_prompt=False, images=None):
-        turns = [
-            {
-                "role": self.roles[t[self.message_field_role]],
-                "content": t[self.message_field_content],
-                "training": t.get(self.message_field_training, None),
-            }
-            for t in conversation
-        ]
-
-        if self.drop_system_message and turns[0]["role"] == "system":
-            turns = turns[1:]
-
         if self.processor:
             text = self.processor.apply_chat_template(
-                turns,
+                conversation,
                 chat_template=self.chat_template,
                 tokenize=False,
                 add_generation_prompt=add_generation_prompt,
@@ -76,8 +65,6 @@ class ChatTemplatePrompter(Prompter):
                 text=text,
                 images=images,
                 return_tensors="pt",
-                truncation=True,
-                max_length=self.max_length,
             )
             # workaround since processor works in batches instead of single examples
             for k, val in batch.items():
@@ -88,9 +75,7 @@ class ChatTemplatePrompter(Prompter):
             return batch
 
         return self.tokenizer.apply_chat_template(
-            turns,
-            truncation=True,
-            max_length=self.max_length,
+            conversation,
             add_generation_prompt=add_generation_prompt,
             chat_template=self.chat_template,
         )
@@ -215,7 +200,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         train_on_eos=None,
     ):
         super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
-        self.roles_to_train = roles_to_train if roles_to_train is not None else []
+
+        self.roles_to_train = []
+        if roles_to_train:
+            # map roles if exist in prompter.roles else use the role as is
+            self.roles_to_train = [
+                prompter.roles.get(role, role) for role in roles_to_train
+            ]
+
         self.train_on_eos = train_on_eos
         self.images = "images"
@@ -262,30 +254,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             return tokenized_prompt
 
-        turns = prompt[self.messages]
+        turns = self.get_conversation_thread(prompt)
         input_ids = self.prompter.build_prompt(turns)
         labels = [IGNORE_TOKEN_ID] * len(input_ids)
 
         last_eos_idx = -1
         for index, turn in enumerate(turns):
-            role = turn.get(self.prompter.message_field_role)
-            content = turn.get(self.prompter.message_field_content)
-            train_turn = turn.get(self.prompter.message_field_training)
-            train_detail = turn.get(self.prompter.message_field_training_detail)
+            role = turn.get("role")
+            content = turn.get("content")
+            train_turn = turn.get("training")
+            train_detail = turn.get("training_detail")
 
             LOG.debug(
                 f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
             )
 
-            should_train = (
-                train_turn
-                if train_turn is not None
-                else (
-                    bool(train_detail is not None)
-                    if train_detail is not None
-                    else self.train_on_inputs or role in self.roles_to_train
-                )
-            )
+            should_train = None
+            if train_turn is not None:
+                should_train = train_turn
+            elif train_detail is not None:
+                should_train = bool(train_detail)
+            else:
+                should_train = self.train_on_inputs or role in self.roles_to_train
 
             LOG.debug(f"Should train: {should_train}")
@@ -293,6 +283,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                 conversation_ids=input_ids, turn=index, turn_content=turn
             )
 
+            if turn_start_idx == -1 or turn_end_idx == -1:
+                LOG.warning(f"Failed to find boundaries for turn {index}")
+
             LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
 
             if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -313,7 +306,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                     labels[turn_start_idx:turn_end_idx] = input_ids[
                         turn_start_idx:turn_end_idx
                     ]
-                    LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
+                    LOG.debug(
+                        f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
+                    )
 
             LOG.debug(f"Labels after processing turn {index}: {labels}")
@@ -351,52 +346,73 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                 return i
         return -1
 
-    def find_turn(self, conversation_ids, turn, turn_content):
+    def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
         """
         Locate the starting and ending indices of the specified turn in a conversation.
+
+        Args:
+            conversation_ids (list[int]): Token IDs representing the conversation.
+            turn (int): The turn number to locate (based on EOS tokens).
+            turn_content (str): String containing the content of the turn.
+
+        Returns:
+            tuple: (start_idx, end_idx) indices of the start and end of the turn content.
+                Returns (-1, -1) if the turn content is not found.
         """
-        content = turn_content.get(self.prompter.message_field_content, "")
+        content = turn_content.get("content")
         content_ids = self.tokenizer.encode(content, add_special_tokens=False)
-        eos_token_id = self.tokenizer.eos_token_id
-        eos_count = 0
-        start_search_idx = 0
+        LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
 
-        # Locate the starting index after the specified number of EOS tokens
-        for i, token_id in enumerate(conversation_ids):
-            if token_id == eos_token_id:
-                eos_count += 1
-                if eos_count == turn:
-                    start_search_idx = (
-                        i + 1
-                    )  # Start searching after the specified turn's EOS token
-                    break
+        if not content_ids:
+            LOG.warning(f"Empty content for turn {turn}")
+            return -1, -1
 
-        # Find the start index of the content within the conversation
-        start_idx = -1
-        for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
-            if conversation_ids[i : i + len(content_ids)] == content_ids:
-                start_idx = i
-                break
-
-        if start_idx != -1:
-            end_idx = start_idx + len(content_ids)
+        # For first turn, start from beginning
+        if turn == 0:
+            start_search_idx = 0
         else:
-            end_idx = -1
-
-        return start_idx, end_idx
+            # For subsequent turns, find the previous EOS token
+            eos_token_id = self.tokenizer.eos_token_id
+            eos_count = 0
+            start_search_idx = 0
+            for i, token_id in enumerate(conversation_ids):
+                if token_id == eos_token_id:
+                    eos_count += 1
+                    if eos_count == turn:  # Find the nth EOS token where n = turn
+                        start_search_idx = i + 1
+                        break
+
+        # we can optimize this to only search for a few tokens from start_search_idx
+        # but it would risk missing the content if it's not found within the first few tokens or
+        # if start_search_idx cannot be found above.
+        last_index = len(conversation_ids) - len(content_ids) + 1
+
+        if last_index < start_search_idx:
+            LOG.warning(
+                f"last_index to search is less than start_search_idx for turn {turn}"
+            )
+            return -1, -1
+
+        # Search for content starting from start_search_idx
+        first_elem = content_ids[0]
+        for i in range(start_search_idx, last_index):
+            # Quick check of first element before doing full comparison
+            if conversation_ids[i] == first_elem:
+                # Check if the rest of the content matches
+                if conversation_ids[i : i + len(content_ids)] == content_ids:
+                    LOG.debug(f"Found turn {turn} content at position {i}")
+                    return i, i + len(content_ids)
+
+        return -1, -1
 
     def get_conversation_thread(self, prompt):
-        return prompt[self.messages]
+        turns = [
+            {
+                "role": self.prompter.roles[t[self.prompter.message_field_role]],
+                "content": t[self.prompter.message_field_content],
+                "training": t.get(self.prompter.message_field_training),
+                "training_detail": t.get(self.prompter.message_field_training_detail),
+            }
+            for t in prompt[self.messages]
+        ]
+
+        if self.prompter.drop_system_message and turns[0]["role"] == "system":
+            turns = turns[1:]
+
+        return turns
 
     def get_images(self, prompt):
         return prompt.get(self.images, None)
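The rewritten `find_turn` above is, at its core, a subsequence search with a cheap first-token pre-check before each full slice comparison. That core can be sketched independently of any tokenizer (names here are illustrative):

```python
def find_subsequence(
    haystack: list[int], needle: list[int], start: int = 0
) -> tuple[int, int]:
    """Return (start_idx, end_idx) of needle inside haystack, or (-1, -1)."""
    if not needle:
        return -1, -1
    # last position where a full-length match can still begin
    last_index = len(haystack) - len(needle) + 1
    if last_index < start:
        return -1, -1
    first = needle[0]
    for i in range(start, last_index):
        # quick check of the first element before the full slice comparison
        if haystack[i] == first and haystack[i : i + len(needle)] == needle:
            return i, i + len(needle)
    return -1, -1


print(find_subsequence([5, 1, 2, 3, 9], [2, 3]))  # prints: (2, 4)
print(find_subsequence([5, 1, 2], [7]))           # prints: (-1, -1)
```

In `find_turn`, `start` comes from scanning past the nth EOS token, so each turn's content is only searched for after the previous turn ends.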

View File

@@ -259,14 +259,7 @@ def train(
             model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
     if not cfg.hub_model_id:
-        from huggingface_hub import HfApi
-        from huggingface_hub.utils import RepositoryNotFoundError
-
         try:
-            # Check to make sure the base model is from HuggingFace not a local directory
-            hf_api = HfApi()
-            hf_api.model_info(cfg.base_model)
-
             model_card_kwarg = {
                 "model_name": cfg.output_dir.lstrip("./")
                 .encode("utf-8")
@@ -274,16 +267,22 @@ def train(
             }
             if cfg.datasets is not None:
                 if cfg.rl is not None or cfg.reward_model:
-                    model_card_kwarg["dataset_name"] = [
+                    dataset_tags = [
                         d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                     ]
+                    if dataset_tags:
+                        # guard as create_model_card may fail if dataset_tags is empty list
+                        model_card_kwarg["dataset_name"] = dataset_tags
                 else:
-                    model_card_kwarg["dataset_tags"] = [
+                    dataset_tags = [
                         d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                     ]
+                    if dataset_tags:
+                        # guard as create_model_card may fail if dataset_tags is empty list
+                        model_card_kwarg["dataset_tags"] = dataset_tags
 
             trainer.create_model_card(**model_card_kwarg)
-        except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
+        except (AttributeError, UnicodeDecodeError):
             pass
     elif cfg.hub_model_id:
         # defensively push to the hub to ensure the model card is updated

View File

@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
         default=None, json_schema_extra={"description": "transformers processor class"}
     )
     trust_remote_code: Optional[bool] = None
+    tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
     model_kwargs: Optional[Dict[str, Any]] = None
 
     @field_validator("trust_remote_code")
@@ -1475,6 +1475,27 @@ class AxolotlInputConfig(
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_kto_config(cls, data):
+        if data.get("rl") == "kto":
+            if data.get("sample_packing") or data.get("eval_sample_packing"):
+                raise ValueError("sample_packing is not supported with kto")
+            if data.get("remove_unused_columns") is not False:
+                raise ValueError("Set `remove_unused_columns: False` when using kto")
+            if data.get("gradient_checkpointing") and not (
+                data.get("gradient_checkpointing_kwargs")
+                and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
+                and data["gradient_checkpointing_kwargs"].get("use_reentrant")
+            ):
+                raise ValueError(
+                    "Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
+                )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
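The `check_kto_config` validator above rejects config combinations that are known not to work with `rl: kto`. The same checks can be sketched as a plain function over a config dict, simplified from the pydantic validator (only the packing and unused-columns clauses are shown; the gradient-checkpointing clause follows the same pattern):

```python
def validate_kto(data: dict) -> None:
    """Raise ValueError for config combinations unsupported with rl: kto."""
    if data.get("rl") != "kto":
        return  # checks only apply to KTO runs
    if data.get("sample_packing") or data.get("eval_sample_packing"):
        raise ValueError("sample_packing is not supported with kto")
    if data.get("remove_unused_columns") is not False:
        raise ValueError("Set `remove_unused_columns: False` when using kto")


# a valid config passes silently
validate_kto({"rl": "kto", "remove_unused_columns": False})

# an invalid one is rejected with an actionable message
err_msg = ""
try:
    validate_kto({"rl": "kto", "sample_packing": True, "remove_unused_columns": False})
except ValueError as err:
    err_msg = str(err)
print(err_msg)  # prints: sample_packing is not supported with kto
```

Running these checks as a `mode="before"` model validator means bad combinations fail at config-load time, before any model weights are downloaded.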

View File

@@ -386,6 +386,12 @@ class ModelLoader:
             )
 
             patch_training_loop_for_fsdp()
+        elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
+            from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_training_loop_for_deepspeed_0_16_x,
+            )
+
+            patch_training_loop_for_deepspeed_0_16_x()
 
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
@@ -1181,9 +1187,15 @@ class ModelLoader:
         gc.collect()
         torch.cuda.empty_cache()
 
+        self.post_loading_set_env()
+
         # TODO resume_from_checkpoint handling
         return self.model, lora_config
 
+    def post_loading_set_env(self):
+        if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
+            os.environ["ACCELERATE_USE_TP"] = "true"
+
 
 def load_model(
     cfg: DictDefault,

View File

View File

@@ -0,0 +1,104 @@
+"""
+dynamic requirements for axolotl
+"""
+import platform
+import re
+from importlib.metadata import PackageNotFoundError, version
+
+from setuptools.command.build_py import build_py as _build_py
+
+
+# pylint: disable=duplicate-code
+def parse_requirements():
+    _install_requires = []
+    _dependency_links = []
+    with open("./requirements.txt", encoding="utf-8") as requirements_file:
+        lines = [r.strip() for r in requirements_file.readlines()]
+        for line in lines:
+            is_extras = (
+                "flash-attn" in line
+                or "flash-attention" in line
+                or "deepspeed" in line
+                or "mamba-ssm" in line
+                or "lion-pytorch" in line
+            )
+            if line.startswith("--extra-index-url"):
+                # Handle custom index URLs
+                _, url = line.split()
+                _dependency_links.append(url)
+            elif not is_extras and line and line[0] != "#":
+                # Handle standard packages
+                _install_requires.append(line)
+
+    try:
+        xformers_version = [req for req in _install_requires if "xformers" in req][0]
+        torchao_version = [req for req in _install_requires if "torchao" in req][0]
+        autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
+
+        if "Darwin" in platform.system():
+            # don't install xformers on MacOS
+            _install_requires.pop(_install_requires.index(xformers_version))
+        else:
+            # detect the version of torch already installed
+            # and set it so dependencies don't clobber the torch version
+            try:
+                torch_version = version("torch")
+            except PackageNotFoundError:
+                torch_version = "2.5.1"
+            _install_requires.append(f"torch=={torch_version}")
+
+            version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
+            if version_match:
+                major, minor, patch = version_match.groups()
+                major, minor = int(major), int(minor)
+                patch = (
+                    int(patch) if patch is not None else 0
+                )  # Default patch to 0 if not present
+            else:
+                raise ValueError("Invalid version format")
+
+            if (major, minor) >= (2, 5):
+                _install_requires.pop(_install_requires.index(xformers_version))
+                if patch == 0:
+                    _install_requires.append("xformers==0.0.28.post2")
+                else:
+                    _install_requires.append("xformers==0.0.28.post3")
+                _install_requires.pop(_install_requires.index(autoawq_version))
+            elif (major, minor) >= (2, 4):
+                if patch == 0:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers>=0.0.27")
+                else:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers==0.0.28.post1")
+            elif (major, minor) >= (2, 3):
+                _install_requires.pop(_install_requires.index(torchao_version))
+                if patch == 0:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers>=0.0.26.post1")
+                else:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers>=0.0.27")
+            elif (major, minor) >= (2, 2):
+                _install_requires.pop(_install_requires.index(torchao_version))
+                _install_requires.pop(_install_requires.index(xformers_version))
+                _install_requires.append("xformers>=0.0.25.post1")
+            else:
+                _install_requires.pop(_install_requires.index(torchao_version))
+                _install_requires.pop(_install_requires.index(xformers_version))
+                _install_requires.append("xformers>=0.0.23.post1")
+    except PackageNotFoundError:
+        pass
+    return _install_requires, _dependency_links
+
+
+class BuildPyCommand(_build_py):
+    """
+    custom build_py command to parse dynamic requirements
+    """
+
+    def finalize_options(self):
+        super().finalize_options()
+        install_requires, _ = parse_requirements()
+        self.distribution.install_requires = install_requires
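The long `if/elif` ladder in the new module above pins `xformers` according to the torch version already installed, so the build respects it instead of clobbering it. The version-key extraction it relies on can be sketched on its own; the pin table here mirrors the 2.4/2.5 branches of the diff but should be treated as illustrative, not a maintained compatibility matrix:

```python
import re


def torch_version_key(torch_version: str) -> tuple[int, int, int]:
    """Parse 'MAJOR.MINOR[.PATCH]' into a comparable tuple, defaulting patch to 0."""
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if not match:
        raise ValueError("Invalid version format")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch) if patch is not None else 0


def xformers_pin(torch_version: str) -> str:
    """Pick an xformers requirement string for the detected torch version."""
    key = torch_version_key(torch_version)
    if key >= (2, 5, 0):
        return "xformers==0.0.28.post2" if key[2] == 0 else "xformers==0.0.28.post3"
    if key >= (2, 4, 0):
        return "xformers>=0.0.27" if key[2] == 0 else "xformers==0.0.28.post1"
    # older torch versions fall through to a loose floor, as in the diff
    return "xformers>=0.0.23.post1"


print(xformers_pin("2.5.1"))  # prints: xformers==0.0.28.post3
```

Comparing `(major, minor, patch)` tuples avoids the classic string-comparison bug where `"2.10" < "2.4"`.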

View File

@@ -0,0 +1,10 @@
+"""pytest tests for axolotl CLI --version"""
+
+from axolotl.cli.main import cli
+
+
+def test_print_version(cli_runner):
+    """Test that version is printed when --version is used."""
+    result = cli_runner.invoke(cli, ["--version"])
+    assert result.exit_code == 0
+    assert "axolotl, version " in result.output

View File

@@ -120,9 +120,15 @@ def temp_dir():
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
     from transformers import Trainer
-    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+    from transformers.models.llama.modeling_llama import (
+        LlamaAttention,
+        LlamaFlashAttention2,
+        LlamaForCausalLM,
+    )
 
     original_fa2_forward = LlamaFlashAttention2.forward
+    original_llama_attn_forward = LlamaAttention.forward
+    original_llama_forward = LlamaForCausalLM.forward
     original_trainer_inner_training_loop = (
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
@@ -131,6 +137,8 @@ def cleanup_monkeypatches():
     yield
     # Reset LlamaFlashAttention2 forward
     LlamaFlashAttention2.forward = original_fa2_forward
+    LlamaAttention.forward = original_llama_attn_forward
+    LlamaForCausalLM.forward = original_llama_forward
     Trainer._inner_training_loop = (  # pylint: disable=protected-access
         original_trainer_inner_training_loop
     )
@@ -138,16 +146,25 @@ def cleanup_monkeypatches():
 
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
-        ("transformers",),
-        ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
-        ("transformers.trainer", ["Trainer"]),
+        ("transformers.models.llama",),
+        (
+            "transformers.models.llama.modeling_llama",
+            ["LlamaFlashAttention2", "LlamaAttention"],
+        ),
+        ("transformers.trainer",),
+        ("transformers", ["Trainer"]),
         ("transformers.loss.loss_utils",),
     ]
     for module_name_tuple in modules_to_reset:
         module_name = module_name_tuple[0]
-        module = importlib.import_module(module_name)
-        sys.modules[module_name] = module
-        importlib.reload(sys.modules[module_name])
+
+        spec = importlib.util.spec_from_file_location(
+            module_name, sys.modules[module_name].__file__
+        )
+        sys.modules[module_name] = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(sys.modules[module_name])
+        sys.modules[module_name] = importlib.reload(sys.modules[module_name])
+
         if len(module_name_tuple) > 1:
             module_globals = module_name_tuple[1]
             for module_global in module_globals:

View File

@@ -71,7 +71,11 @@ class TestCutCrossEntropyIntegration:
     @pytest.mark.parametrize(
         "attention_type",
-        ["flash_attention", "sdp_attention", "xformers_attention"],
+        [
+            "flash_attention",
+            "sdp_attention",
+            # "xformers_attention",
+        ],
     )
     def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
         cfg = DictDefault(

View File

@@ -9,6 +9,7 @@ from pathlib import Path
 import pytest
 import yaml
 from accelerate.test_utils import execute_subprocess_async
+from e2e.utils import check_tensorboard
 from huggingface_hub import snapshot_download
 from transformers.testing_utils import get_torch_dist_unique_port
@@ -53,7 +54,7 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
+            "max_steps": 2,
             "micro_batch_size": 4,
             "gradient_accumulation_steps": 4,
             "output_dir": temp_dir,
@@ -61,6 +62,7 @@ class TestMultiGPULlama:
             "optimizer": "adamw_8bit",
             "lr_scheduler": "cosine",
             "flash_attention": True,
+            "use_tensorboard": True,
         }
     )
@@ -83,9 +85,13 @@ class TestMultiGPULlama:
] ]
) )
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 4], [1, 2],
) )
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps): def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -112,14 +118,15 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
-            "micro_batch_size": 4,
+            "max_steps": 2,
+            "micro_batch_size": 1,
             "gradient_accumulation_steps": gradient_accumulation_steps,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
             "optimizer": "adamw_8bit",
             "lr_scheduler": "cosine",
             "flash_attention": True,
+            "use_tensorboard": True,
         }
     )
@@ -142,6 +149,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_dpo_lora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -180,7 +191,7 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
+            "max_steps": 2,
             "micro_batch_size": 4,
             "gradient_accumulation_steps": 4,
             "output_dir": temp_dir,
@@ -189,6 +200,7 @@ class TestMultiGPULlama:
             "optimizer": "adamw_8bit",
             "lr_scheduler": "cosine",
             "flash_attention": True,
+            "use_tensorboard": True,
         }
     )
@@ -211,6 +223,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_dpo_qlora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -249,8 +265,8 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
-            "micro_batch_size": 4,
+            "max_steps": 2,
+            "micro_batch_size": 2,
             "gradient_accumulation_steps": 4,
             "output_dir": temp_dir,
             "warmup_steps": 0,
@@ -258,6 +274,7 @@ class TestMultiGPULlama:
             "optimizer": "adamw_8bit",
             "lr_scheduler": "cosine",
             "flash_attention": True,
+            "use_tensorboard": True,
         }
     )
@@ -280,9 +297,13 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
-        [1, 4],
+        [1, 2],
     )
     def test_fsdp(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
@@ -301,8 +322,8 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 10,
-            "micro_batch_size": 4,
+            "max_steps": 2,
+            "micro_batch_size": 2,
             "gradient_accumulation_steps": gradient_accumulation_steps,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
@@ -323,6 +344,7 @@ class TestMultiGPULlama:
                 "fsdp_state_dict_type": "FULL_STATE_DICT",
                 "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
             },
+            "use_tensorboard": True,
         }
     )
@@ -345,6 +367,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "fsdp_state_dict_type",
         ["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
@@ -368,7 +394,7 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
+            "max_steps": 2,
             "micro_batch_size": 4,
             "gradient_accumulation_steps": 4,
             "output_dir": temp_dir,
@@ -390,6 +416,7 @@ class TestMultiGPULlama:
                 "fsdp_state_dict_type": fsdp_state_dict_type,
                 "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
             },
+            "use_tensorboard": True,
         }
     )
@@ -412,6 +439,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_fsdp_qlora_prequant_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -444,7 +475,7 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
+            "max_steps": 2,
             "micro_batch_size": 4,
             "gradient_accumulation_steps": 4,
             "output_dir": temp_dir,
@@ -466,6 +497,7 @@ class TestMultiGPULlama:
                 "fsdp_state_dict_type": "SHARDED_STATE_DICT",
                 "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
             },
+            "use_tensorboard": True,
         }
     )
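The hunks below parametrize `test_ds_zero3_packed` over ZeRO-3 JSON configs such as `deepspeed_configs/zero3_bf16.json`. For orientation, a ZeRO stage-3 config with bf16 enabled follows this general shape — a minimal sketch built from DeepSpeed's documented keys, not the repo's actual file, which may carry additional fields:

```python
import json

# Sketch of a ZeRO stage-3 + bf16 DeepSpeed config; "auto" values are
# resolved by the HF Trainer integration at launch time.
zero3_bf16 = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

print(json.dumps(zero3_bf16, indent=2))
```

The `zero3_bf16_cpuoffload_*` variants differ mainly by adding `offload_param`/`offload_optimizer` sections under `zero_optimization`.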
@@ -488,12 +520,41 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
-        [1, 4],
+        [1, 2],
     )
-    def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
+    @pytest.mark.parametrize(
+        "deepspeed",
+        [
+            "deepspeed_configs/zero3_bf16.json",
+            "deepspeed_configs/zero3_bf16_cpuoffload_all.json",
+            # "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
+        ],
+    )
+    @pytest.mark.parametrize(
+        "qlora",
+        [True, False],
+    )
+    def test_ds_zero3_packed(
+        self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
+    ):
         # pylint: disable=duplicate-code
+        if qlora:
+            adapter = {
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "load_in_4bit": True,
+            }
+        else:
+            adapter = {}
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -511,15 +572,17 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
-            "micro_batch_size": 4,
+            "max_steps": 2,
+            "micro_batch_size": 1,
             "gradient_accumulation_steps": gradient_accumulation_steps,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
             "optimizer": "adamw_torch",
             "lr_scheduler": "cosine",
             "flash_attention": True,
-            "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+            "deepspeed": str(AXOLOTL_ROOT / deepspeed),
+            "use_tensorboard": True,
+            **adapter,
         }
     )
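The `qlora` parametrization in this change toggles an adapter dict that is splatted into the config with `**adapter`; when the flag is off, the empty dict unpacks to nothing. Stripped of the test harness, the pattern reduces to this sketch (key names trimmed for brevity):

```python
def build_cfg(qlora: bool) -> dict:
    # Layer adapter settings over the base config only when qlora is on.
    adapter = (
        {
            "adapter": "qlora",
            "lora_r": 8,
            "lora_alpha": 16,
            "load_in_4bit": True,
        }
        if qlora
        else {}
    )
    return {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "max_steps": 2,
        **adapter,  # empty dict unpacks to no extra keys
    }

print(build_cfg(True)["adapter"])  # -> qlora
print(len(build_cfg(False)))       # -> 2 (base keys only)
```

Because `**adapter` comes last, any key it shares with the base dict would win, which is why `load_in_4bit` can live inside the adapter block.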
@@ -542,19 +605,35 @@ class TestMultiGPULlama:
             ]
         )
 
-    def test_ds_zero3_qlora_packed(self, temp_dir):
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 2],
+    )
+    @pytest.mark.parametrize(
+        "qlora",
+        [True, False],
+    )
+    def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
         # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "load_in_4bit": True,
+        if qlora:
+            adapter = {
                 "adapter": "qlora",
                 "lora_r": 8,
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
+                "load_in_4bit": True,
+            }
+        else:
+            adapter = {}
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sample_packing": True,
-                "eval_sample_packing": False,
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "val_set_size": 0.05,
@@ -568,15 +647,17 @@ class TestMultiGPULlama:
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
-            "micro_batch_size": 4,
-            "gradient_accumulation_steps": 4,
+            "max_steps": 2,
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": gradient_accumulation_steps,
             "output_dir": temp_dir,
-            "learning_rate": 0.0001,
+            "learning_rate": 0.00001,
             "optimizer": "adamw_torch",
             "lr_scheduler": "cosine",
             "flash_attention": True,
-            "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+            "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
+            "use_tensorboard": True,
+            **adapter,
         }
     )
@@ -598,3 +679,82 @@ class TestMultiGPULlama:
                 str(Path(temp_dir) / "config.yaml"),
             ]
         )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 2],
+    )
+    @pytest.mark.parametrize(
+        "qlora",
+        [True, False],
+    )
+    def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
+        # pylint: disable=duplicate-code
+        if qlora:
+            adapter = {
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "load_in_4bit": True,
+            }
+        else:
+            adapter = {}
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
+                "use_tensorboard": True,
+                **adapter,
+            }
+        )
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
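Each `accelerate launch` in these tests passes `--main_process_port` from `get_torch_dist_unique_port()` so parallel test runs don't collide on the rendezvous port. One common way to pick a free port — a stdlib sketch, not necessarily how axolotl's helper is implemented — is to bind port 0 and read back the port the OS assigned:

```python
import socket

def get_free_port() -> int:
    # Binding port 0 asks the OS for an unused ephemeral port; we read
    # it back, then close the socket so the launcher can claim it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]

port = get_free_port()
print(port)
```

There is a small race window between releasing the socket and the launcher rebinding it, which is usually acceptable in test harnesses.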

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
 import logging
 import os
-from importlib import reload
 from pathlib import Path
 
 import pytest
@@ -22,14 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-@pytest.fixture(autouse=True)
-def reload_transformers():
-    import transformers.models.llama.modeling_llama
-
-    yield
-    reload(transformers.models.llama.modeling_llama)
-
 
 class TestFAXentropyLlama:
     """
     Test case for Llama models using LoRA w multipack

View File

@@ -7,6 +7,7 @@ import os
 import unittest
 from pathlib import Path
 
+import pytest
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -21,6 +22,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
+@pytest.mark.skip("FIXME, mostly underused functionality")
 class TestFusedLlama(unittest.TestCase):
     """
     Test case for Llama models using Fused layers
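The hunk above parks the whole `TestFusedLlama` class with `@pytest.mark.skip`, which works even though the class subclasses `unittest.TestCase`. The stdlib analog of the same effect is `unittest.skip` — a sketch with a stand-in class, not the actual test suite:

```python
import unittest

@unittest.skip("FIXME, mostly underused functionality")
class TestFused(unittest.TestCase):
    # Every test in a class-level-skipped TestCase is reported as
    # skipped; the test body never runs.
    def test_forward(self):
        raise AssertionError("never executed")

suite = unittest.TestLoader().loadTestsFromTestCase(TestFused)
outcome = unittest.TextTestRunner(verbosity=0).run(suite)
print(len(outcome.skipped))  # -> 1
```

Under pytest, `@pytest.mark.skip` on a `unittest.TestCase` subclass is reported the same way in the collection summary.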