Compare commits

...

29 Commits

Author SHA1 Message Date
bursteratom
60c98a4353 stuff 2024-12-13 15:44:51 -05:00
bursteratom
c760d2b815 test accelerator 2024-12-12 12:29:35 -05:00
bursteratom
2014f58181 set os environ RANK 2024-12-11 11:45:07 -05:00
bursteratom
b5f9dd44f2 set os environ RANK 2024-12-11 11:40:20 -05:00
bursteratom
b17b1aada7 initialise process group for tp 2024-12-11 11:37:21 -05:00
bursteratom
85381b6b15 initialise process group for tp 2024-12-11 11:35:16 -05:00
bursteratom
acde081321 test lora tp 2024-12-11 11:19:34 -05:00
bursteratom
e4c68a0cbc test lora tp 2024-12-11 11:11:52 -05:00
bursteratom
3855f5c3d3 tp example tp auto 2024-12-11 11:03:39 -05:00
bursteratom
5dd566dc63 tp example 2024-12-11 11:01:23 -05:00
bursteratom
42389c1f78 enable tensor parallel 2024-12-11 10:38:14 -05:00
Wing Lian
d009ead101 fix build w pyproject to respect insalled torch version (#2168)
* fix build w pyproject to respect insalled torch version

* include in manifest

* disable duplicate code check for now

* move parser so it can be found

* add checks for correct pytorch version so this doesn't slip by again
2024-12-10 16:25:25 -05:00
Wing Lian
6aa31b44c6 make sure to checkout tag before creating release (#2164)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl (mamba-ssm, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, 3.10, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 121, 12.1.1, true, 3.11, 2.3.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 121, 12.1.1, 3.11, 2.3.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2024-12-09 14:20:16 -05:00
Wing Lian
9001859b0b fix release command (#2163) [skip ci] 2024-12-09 14:12:45 -05:00
Wing Lian
34d3c8dcfb [docs] Update README Quickstart to use CLI (#2137)
* update quickstart for new CLI

* add blurb about bleeding edge builds

* missed a yaml reference

* prefer lora over qlora for examples

* fix commands for parity with previous instructions

* consistency on pip/pip3 install

* one more parity pip=>pip3

* remove extraneous options in example yaml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* update copy

* update badges and for discord and socials in readme

* Fix a few broken links

* bump version to 0.6.0 for release

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-09 14:03:19 -05:00
Wing Lian
ab4b32187d need to update deepspeed version in extras too (#2161) [skip ci]
* need to update deepspeed version in extras too

* fix patch import

* fix monkeypatch reloading in tests and deepspeed patch

* remove duplicated functionality fixture

* reset LlamaForCausalLM too in fixtures for cce patch

* reset llama attn too

* disable xformers patch for cce

* skip problematic test on low usage functionality
2024-12-09 14:01:44 -05:00
NanoCode012
5d6b088997 fix: chat_template masking due to truncation, consolidate turn build and keys within field (#2123) [skip ci]
* fix: chat_template masking due to truncation, consolidate turn build and keys within field

* fix: revert roles change

* fix: handling of training and training_detail

* fix: do not skip setting eos mask even if failed finding turn boundary

* fix: truncate reward modelling outputs
2024-12-09 13:49:38 -05:00
Wing Lian
3862267040 don't add dataset tags if empty due to all local data paths (#2162) [skip ci] 2024-12-09 13:49:18 -05:00
NanoCode012
c78de6f214 feat: add kto example (#2158) [skip ci] 2024-12-09 08:17:27 -05:00
Wing Lian
b1e8286c57 add missing __init__ to optimizers path (#2160) [skip ci] 2024-12-09 08:17:08 -05:00
Wing Lian
40907c6887 upgrade deepspeed to 0.16.1 (#2157) 2024-12-09 07:25:10 -05:00
NanoCode012
6a342feda2 fix: duplicate mlflow logging (#2109) [skip ci] 2024-12-09 07:24:48 -05:00
Wing Lian
0c25bc07a2 use manual version for now (#2156) 2024-12-08 21:09:12 -05:00
Sunny Liu
343a4d8855 Fixing issue#2134 Axolotl Crashes At The End Of Training If Base Model Is Local (#2140) 2024-12-08 16:39:05 -05:00
Wing Lian
393853751e add additional fft deepspeed variants (#2153) [skip ci] 2024-12-08 16:38:47 -05:00
Wing Lian
1302e31049 Transformers version flexibility and FSDP optimizer patch (#2155)
* allow flexibility in transformers version for FSDP

* more flexibility with dev versions of 4.47.0.dev0

* add patch for fsdp

* fix typo

* correct fn name

* stray character

* fix patch

* reset Trainer too

* also reset Trainer.training_step

* allow tests/patched to run more than one process on e2e runner

* skip tests/patched in e2e for now since it's run in regular pytest
2024-12-08 14:50:40 -05:00
Wing Lian
be5f554a62 bump autoawq to 0.2.7.post3 (#2150) 2024-12-07 22:24:09 -05:00
Wing Lian
22319182ab fix for auto_map check when using remote code and multipack for models like deepseek (#2151) [skip ci] 2024-12-07 22:23:52 -05:00
Wing Lian
440aab8a6f add --version support to axolotl cli (#2152) [skip ci] 2024-12-07 22:23:33 -05:00
41 changed files with 1175 additions and 235 deletions

View File

@@ -13,10 +13,13 @@ jobs:
permissions:
contents: write
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Create release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
run: gh release create "$GITHUB_REF_NAME" --generate-notes
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
@@ -38,7 +41,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging
pip3 install -e .
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name

View File

@@ -60,11 +60,15 @@ jobs:
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help

View File

@@ -78,19 +78,23 @@ jobs:
- name: Install dependencies
run: |
pip3 show torch
pip3 install -U -e .
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/
- name: cleanup pip cache
run: |
@@ -120,7 +124,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
- name: Install PyTorch
run: |
@@ -129,20 +133,24 @@ jobs:
- name: Install dependencies
run: |
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
python -m build --no-isolation --sdist
pip3 install --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/
- name: cleanup pip cache
run: |

View File

@@ -1,4 +1,5 @@
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
recursive-include axolotl *.py

104
README.md
View File

@@ -10,9 +10,13 @@
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
</p>
<p align="center">
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
@@ -42,7 +46,8 @@ Features:
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Quickstart ⚡](#quickstart-)
- [Usage](#usage)
- [Edge Builds](#edge-builds-)
- [Axolotl CLI Usage](#axolotl-cli-usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
@@ -107,58 +112,49 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# download examples and optionally deepspeed configs to the local path
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```
### Edge Builds 🏎️
If you're looking for the latest features and updates between releases, you'll need to install
from source.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Usage
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```
### Axolotl CLI
If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:
### Axolotl CLI Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
# finetune lora
axolotl train examples/openllama-3b/lora.yml
axolotl train examples/llama-3/lora-1b.yml
# inference
axolotl inference examples/openllama-3b/lora.yml \
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out"
# gradio
axolotl inference examples/openllama-3b/lora.yml \
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
@@ -175,6 +171,36 @@ axolotl fetch deepspeed_configs
axolotl fetch examples --dest path/to/folder
```
### Legacy Usage
<details>
<summary>Click to Expand</summary>
While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
</details>
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
@@ -294,7 +320,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
@@ -373,7 +399,7 @@ Please use WSL or Docker!
Use the below instead of the install method in QuickStart.
```
pip3 install -e '.'
pip3 install --no-build-isolation -e '.'
```
More info: [mac.md](/docs/mac.qmd)

View File

@@ -31,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -1,7 +1,10 @@
#!/bin/bash
set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
cd flash-attention
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
pip install .
pip install --no-build-isolation .
```
### 6. Install Axolotl
@@ -63,7 +63,7 @@ Clone and install Axolotl:
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip install packaging ninja
pip install -e .
pip install --no-build-isolation -e .
```
### 7. Apply xformers Workaround

View File

@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
```bash
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
#### Remote Hosts
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:
```bash
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Attach To Container

View File

@@ -52,6 +52,26 @@ datasets:
type: chat_template.argilla
```
#### KTO
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
remove_unused_columns: false
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
```
#### Using local dataset files
```yaml
datasets:

View File

@@ -24,7 +24,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install axolotl[deepspeed]"
"!pip install --no-build-isolation axolotl[deepspeed]"
]
},
{

View File

@@ -0,0 +1,58 @@
base_model: NousResearch/Meta-Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,74 @@
base_model: NousResearch/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -0,0 +1,73 @@
base_model: NousResearch/Meta-Llama-3.1-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
tensor_parallel: 'auto'
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,75 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
type: llama3.ultra
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/qlora-out
remove_unused_columns: false
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: false # not supported with kto
eval_sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Llama-3.2-1B
base_model: NousResearch/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
@@ -22,7 +22,6 @@ pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj

View File

@@ -17,3 +17,10 @@ Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools_scm]
[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"

View File

@@ -1,22 +1,30 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.4.2
# END section
packaging==23.2
peft==0.14.0
transformers==4.47.0
transformers>=4.46.3
tokenizers>=0.20.1
bitsandbytes==0.45.0
accelerate==1.2.0
datasets==3.1.0
deepspeed==0.15.4
deepspeed==0.16.1
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
flash-attn==2.7.0.post2
sentencepiece
wandb
einops
xformers>=0.0.23.post1
optimum==1.16.2
hf_transfer
colorama
@@ -31,11 +39,6 @@ art
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post2
triton>=2.3.0
liger-kernel==0.4.2
mamba-ssm==1.2.0.post1
# remote filesystems
s3fs>=2024.5.0

View File

@@ -13,5 +13,5 @@ cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip install --no-deps -e .
pip install --no-build-isolation --no-deps -e .
```

View File

@@ -1,7 +1,10 @@
"""setup.py for axolotl"""
import ast
import os
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from setuptools import find_packages, setup
@@ -90,9 +93,24 @@ def parse_requirements():
return _install_requires, _dependency_links
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
"r",
encoding="utf-8",
) as fin:
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
return version_
install_requires, dependency_links = parse_requirements()
setup(
version=get_package_version(),
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
@@ -107,7 +125,7 @@ setup(
"flash-attn==2.7.0.post2",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed==0.16.1",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -1,8 +1,3 @@
"""Axolotl - Train and fine-tune large language models"""
try:
from importlib.metadata import version
__version__ = version("axolotl")
except ImportError:
__version__ = "unknown"
__version__ = "0.6.0"

View File

@@ -5,6 +5,7 @@ from typing import Optional
import click
import axolotl
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -16,6 +17,7 @@ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -973,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1165,9 +1172,13 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
@@ -1185,9 +1196,13 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
@@ -1232,9 +1247,13 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
@@ -1252,9 +1271,13 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
@@ -1266,9 +1289,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
@@ -1293,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
# self.model =
@property
def model_ref(self):
return self._model_ref
@@ -1342,8 +1372,6 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
@@ -1351,7 +1379,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():

View File

@@ -0,0 +1,80 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
ORIGINAL_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
"""
PATCHED_TRAINER_CODE = """
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_fsdp():
"""
monkeypatch for fixing the training loop for fsdp with optimizer save
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -5,8 +5,7 @@ see https://github.com/huggingface/transformers/pull/35128
import inspect
import logging
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from transformers import LlamaForCausalLM, Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
@@ -205,3 +204,87 @@ def patch_forward_for_ga():
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -28,6 +28,8 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:return:
"""
max_length = self.prompter.max_length
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -39,6 +41,16 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
:max_length
]
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -52,6 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
)
rejected_tokenized = super().tokenize_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
:max_length
]
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
:max_length
]
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
"max_length": (
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
),
}
strategy_params = {

View File

@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
"gpt": "assistant",
"system": "system",
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
@@ -53,21 +54,9 @@ class ChatTemplatePrompter(Prompter):
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
turns,
conversation,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
@@ -76,8 +65,6 @@ class ChatTemplatePrompter(Prompter):
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
@@ -88,9 +75,7 @@ class ChatTemplatePrompter(Prompter):
return batch
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
max_length=self.max_length,
conversation,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
@@ -215,7 +200,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.roles_to_train = []
if roles_to_train:
# map roles if exist in prompter.roles else use the role as is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]
self.train_on_eos = train_on_eos
self.images = "images"
@@ -262,30 +254,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
turns = prompt[self.messages]
turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
for index, turn in enumerate(turns):
role = turn.get(self.prompter.message_field_role)
content = turn.get(self.prompter.message_field_content)
train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
role = turn.get("role")
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)
should_train = (
train_turn
if train_turn is not None
else (
bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
)
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train
LOG.debug(f"Should train: {should_train}")
@@ -293,6 +283,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
conversation_ids=input_ids, turn=index, turn_content=turn
)
if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -313,7 +306,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
LOG.debug(f"Labels after processing turn {index}: {labels}")
@@ -351,52 +346,73 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i
return -1
def find_turn(self, conversation_ids, turn, turn_content):
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (str): String containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get(self.prompter.message_field_content, "")
content = turn_content.get("content")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
if not content_ids:
LOG.warning(f"Empty content for turn {turn}")
return -1, -1
# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
# For first turn, start from beginning
if turn == 0:
start_search_idx = 0
else:
end_idx = -1
# For subsequent turns, find the previous EOS token
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
return start_idx, end_idx
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break
# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1
if last_index < start_search_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
)
return -1, -1
# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)
return -1, -1
def get_conversation_thread(self, prompt):
return prompt[self.messages]
turns = [
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
]
if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
return turns
def get_images(self, prompt):
return prompt.get(self.images, None)

View File

@@ -259,14 +259,7 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
try:
# Check to make sure the base model is from HuggingFace not a local directory
hf_api = HfApi()
hf_api.model_info(cfg.base_model)
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
@@ -274,16 +267,22 @@ def train(
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model:
model_card_kwarg["dataset_name"] = [
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else:
model_card_kwarg["dataset_tags"] = [
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@@ -1475,6 +1475,27 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_kto_config(cls, data):
if data.get("rl") == "kto":
if data.get("sample_packing") or data.get("eval_sample_packing"):
raise ValueError("sample_packing is not supported with kto")
if data.get("remove_unused_columns") is not False:
raise ValueError("Set `remove_unused_columns: False` when using kto")
if data.get("gradient_checkpointing") and not (
data.get("gradient_checkpointing_kwargs")
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
):
raise ValueError(
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -380,6 +380,19 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,
)
patch_training_loop_for_fsdp()
elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
from axolotl.monkeypatch.trainer_grad_accum import (
patch_training_loop_for_deepspeed_0_16_x,
)
patch_training_loop_for_deepspeed_0_16_x()
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
@@ -406,10 +419,14 @@ class ModelLoader:
and self.cfg.flash_attention
and self.cfg.sample_packing
):
has_remote_code = (
"auto_map" in self.model_config
and "AutoModelForCausalLM" in self.model_config["auto_map"]
)
if "auto_map" in self.model_config:
try:
auto_map_config = self.model_config["auto_map"]
except TypeError:
auto_map_config = self.model_config.auto_map
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
@@ -1170,9 +1187,15 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
self.post_loading_set_env()
# TODO resume_from_checkpoint handling
return self.model, lora_config
def post_loading_set_env(self):
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
def load_model(
cfg: DictDefault,

View File

View File

@@ -0,0 +1,104 @@
"""
dynamic requirements for axolotl
"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools.command.build_py import build_py as _build_py
# pylint: disable=duplicate-code
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
class BuildPyCommand(_build_py):
"""
custom build_py command to parse dynamic requirements
"""
def finalize_options(self):
super().finalize_options()
install_requires, _ = parse_requirements()
self.distribution.install_requires = install_requires

View File

@@ -0,0 +1,10 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli
def test_print_version(cli_runner):
"""Test that version is printed when --version is used."""
result = cli_runner.invoke(cli, ["--version"])
assert result.exit_code == 0
assert "axolotl, version " in result.output

View File

@@ -119,25 +119,52 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from transformers import Trainer
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
),
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
spec = importlib.util.spec_from_file_location(
module_name, sys.modules[module_name].__file__
)
sys.modules[module_name] = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sys.modules[module_name])
sys.modules[module_name] = importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:

View File

@@ -71,7 +71,11 @@ class TestCutCrossEntropyIntegration:
@pytest.mark.parametrize(
"attention_type",
["flash_attention", "sdp_attention", "xformers_attention"],
[
"flash_attention",
"sdp_attention",
# "xformers_attention",
],
)
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cfg = DictDefault(

View File

@@ -9,6 +9,7 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
@@ -53,7 +54,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -61,6 +62,7 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -83,9 +85,13 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
[1, 2],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
@@ -112,14 +118,15 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -142,6 +149,10 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -180,7 +191,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -189,6 +200,7 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -211,6 +223,10 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -249,8 +265,8 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
@@ -258,6 +274,7 @@ class TestMultiGPULlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
}
)
@@ -280,9 +297,13 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
[1, 2],
)
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
@@ -301,8 +322,8 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 10,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -323,6 +344,7 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
}
)
@@ -345,6 +367,10 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"fsdp_state_dict_type",
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
@@ -368,7 +394,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -390,6 +416,7 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": fsdp_state_dict_type,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
}
)
@@ -412,6 +439,10 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -444,7 +475,7 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -466,6 +497,7 @@ class TestMultiGPULlama:
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
}
)
@@ -488,12 +520,41 @@ class TestMultiGPULlama:
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
[1, 2],
)
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
@pytest.mark.parametrize(
"deepspeed",
[
"deepspeed_configs/zero3_bf16.json",
"deepspeed_configs/zero3_bf16_cpuoffload_all.json",
# "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero3_packed(
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -511,15 +572,17 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
**adapter,
}
)
@@ -542,19 +605,35 @@ class TestMultiGPULlama:
]
)
def test_ds_zero3_qlora_packed(self, temp_dir):
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
@@ -568,15 +647,17 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
**adapter,
}
)
@@ -598,3 +679,82 @@ class TestMultiGPULlama:
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
@pytest.mark.parametrize(
"qlora",
[True, False],
)
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
**adapter,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging
import os
from importlib import reload
from pathlib import Path
import pytest
@@ -22,14 +21,6 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.fixture(autouse=True)
def reload_transformers():
import transformers.models.llama.modeling_llama
yield
reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama:
"""
Test case for Llama models using LoRA w multipack

View File

@@ -7,6 +7,7 @@ import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
@@ -21,6 +22,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("FIXME, mostly underused functionality")
class TestFusedLlama(unittest.TestCase):
"""
Test case for Llama models using Fused layers