# Compare commits: pixtral_in...activation

58 commits (authors and dates were not captured in this extract):

7ac9cbebb9, 15f2fa4c8e, 43a2f9a155, 8b79f1cbf6, 3872d5eaed, 02629c7cdf, 78a4aa86d6, d009ead101, 6aa31b44c6, 9001859b0b, 34d3c8dcfb, ab4b32187d, 5d6b088997, 3862267040, c78de6f214, b1e8286c57, 40907c6887, 6a342feda2, 0c25bc07a2, 343a4d8855, 393853751e, 1302e31049, be5f554a62, 22319182ab, 440aab8a6f, 5bef19064b, 743ba62bd5, f9a7748bd8, 5e9fa33f3d, 08fa133177, 6b3058b2dc, 5726141c4e, 2f3ebbc44f, fc973f4322, e399ba533e, 4baf8e5e96, d7d2fd366e, e2882dd749, a1790f2652, 418ad2b586, d87df2c776, 1ef70312ba, 81ef3e45f7, bd8436bc6e, fc6188cd76, b9bb02406a, ff4794cd8e, 822c904092, d5f58b6509, 9f6d0b5587, 53963c792c, a4f4a56d77, ce5bcff750, b620ed94d0, 5f1d98e8fc, 1cf7075d18, f4cabc2351, 6e0fb4a6b2

The file-by-file diff for this comparison follows.
**.github/workflows/pypi.yml** (7 changed lines)

````diff
@@ -13,10 +13,13 @@ jobs:
     permissions:
       contents: write
     steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
       - name: Create release
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
+        run: gh release create "$GITHUB_REF_NAME" --generate-notes
   pypi-publish:
     name: Upload release to PyPI
     runs-on: ubuntu-latest
@@ -38,7 +41,7 @@ jobs:
       - name: Install dependencies
        run: |
          pip3 install wheel packaging
-         pip3 install -e .
+         pip3 install --no-build-isolation -e .
          pip3 install -r requirements-dev.txt -r requirements-tests.txt

      - name: Extract tag name
````
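The change repeated throughout this comparison is adding `--no-build-isolation` wherever axolotl is pip-installed. Under PEP 517 build isolation, pip builds the package inside a throwaway environment containing only the declared build requirements, so build-time code that inspects the already-installed PyTorch (axolotl's `setup.py` imports `importlib.metadata` for exactly this kind of probing) cannot see it. A minimal sketch of the failure mode, not axolotl's actual build code:

```python
# Sketch only: build-time probing of the host environment's torch.
# Under pip's default PEP 517 build isolation this runs in a clean venv
# where torch is absent; --no-build-isolation runs it in the live env.
from importlib.metadata import PackageNotFoundError, version

try:
    torch_version = version("torch")
except PackageNotFoundError:
    torch_version = None  # what an isolated build would observe

print(f"resolving dependencies against torch=={torch_version}")
```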
**.github/workflows/tests-nightly.yml** (25 changed lines)

````diff
@@ -23,9 +23,15 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
+      max-parallel: 2
       matrix:
         python_version: ["3.10", "3.11"]
         pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
+        exclude:
+          - python_version: "3.10"
+            pytorch_version: "2.4.1"
+          - python_version: "3.10"
+            pytorch_version: "2.5.1"
     timeout-minutes: 20

     steps:
@@ -38,6 +44,11 @@ jobs:
           python-version: ${{ matrix.python_version }}
           cache: 'pip' # caching pip dependencies

+      - name: upgrade pip
+        run: |
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging setuptools wheel
+
       - name: Install PyTorch
         run: |
           pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -54,13 +65,23 @@ jobs:
         run: |
           pip3 install --upgrade pip
           pip3 install --upgrade packaging
-          pip3 install -U -e .
+          pip3 install --no-build-isolation -U -e .
+          python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
+
+      - name: Ensure axolotl CLI was installed
+        run: |
+          axolotl --help
+
       - name: Run tests
         run: |
-          pytest --ignore=tests/e2e/ tests/
+          pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+          pytest tests/patched/

       - name: cleanup pip cache
         run: |
````
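The new "Make sure PyTorch version wasn't clobbered" step uses a substring check rather than equality because installed torch builds carry a local version tag. A worked example with the CPU wheels these workflows install:

```python
# torch.__version__ for the CPU wheels looks like "2.5.1+cpu", so an
# equality test against the pinned "2.5.1" would fail even when the
# right version is installed; the substring test passes.
pinned = "2.5.1"
installed = "2.5.1+cpu"  # representative torch.__version__ value
assert pinned in installed
assert pinned != installed
```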
**.github/workflows/tests.yml** (41 changed lines)

````diff
@@ -10,6 +10,7 @@ on:
       - '.github/workflows/*.yml'
       - 'requirements-tests.txt'
       - 'cicd/cicd.sh'
+      - 'cicd/Dockerfile.jinja'
   pull_request:
     paths:
       - '**.py'
@@ -17,6 +18,7 @@ on:
       - '.github/workflows/*.yml'
       - 'requirements-tests.txt'
       - 'cicd/cicd.sh'
+      - 'cicd/Dockerfile.jinja'
   workflow_dispatch:

 # Cancel jobs on the same ref if a new one is triggered
@@ -43,9 +45,15 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
+      max-parallel: 2
       matrix:
         python_version: ["3.10", "3.11"]
         pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
+        exclude:
+          - python_version: "3.10"
+            pytorch_version: "2.4.1"
+          - python_version: "3.10"
+            pytorch_version: "2.5.1"
     timeout-minutes: 20

     steps:
@@ -70,14 +78,23 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 show torch
-          pip3 install -U -e .
+          pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
+
+      - name: Ensure axolotl CLI was installed
+        run: |
+          axolotl --help
+
       - name: Run tests
         run: |
-          pytest -n8 --ignore=tests/e2e/ tests/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+          pytest -v tests/patched/

       - name: cleanup pip cache
         run: |
@@ -88,6 +105,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
+      max-parallel: 1
       matrix:
         python_version: ["3.11"]
         pytorch_version: ["2.4.1", "2.5.1"]
@@ -106,7 +124,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
+          pip3 install --upgrade packaging setuptools setuptools_scm build wheel

       - name: Install PyTorch
         run: |
@@ -115,13 +133,24 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 show torch
-          python3 setup.py sdist
-          pip3 install dist/axolotl*.tar.gz
+          python -m build --no-isolation --sdist
+          pip3 install --no-build-isolation dist/axolotl*.tar.gz
+          python scripts/unsloth_install.py | sh
+          python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
+
+      - name: Ensure axolotl CLI was installed
+        run: |
+          axolotl --help
+
       - name: Run tests
         run: |
-          pytest -n8 --ignore=tests/e2e/ tests/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+          pytest -v tests/patched/

       - name: cleanup pip cache
         run: |
````
The package manifest (filename not preserved in this extract) picks up the new dynamic-dependencies shim:

````diff
@@ -1,4 +1,5 @@
 include requirements.txt
 include README.md
 include LICENSE
+include src/setuptools_axolotl_dynamic_dependencies.py
 recursive-include axolotl *.py
````
**README.md** (289 changed lines)

````diff
@@ -10,9 +10,13 @@
     <img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
     <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
     <a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
+    <br/>
+    <a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
     <img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
-</p>
-<p align="center">
+    <br/>
+    <a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
+    <a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
+    <br/>
     <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
     <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
 </p>
@@ -41,9 +45,13 @@ Features:
 ## Table of Contents
 - [Axolotl](#axolotl)
 - [Table of Contents](#table-of-contents)
-- [Axolotl supports](#axolotl-supports)
 - [Quickstart ⚡](#quickstart-)
-- [Usage](#usage)
+- [Edge Builds](#edge-builds-)
+- [Axolotl CLI Usage](#axolotl-cli-usage)
+- [Badge ❤🏷️](#badge-️)
+- [Contributing 🤝](#contributing-)
+- [Sponsors 🤝❤](#sponsors-)
+- [Axolotl supports](#axolotl-supports)
 - [Advanced Setup](#advanced-setup)
 - [Environment](#environment)
 - [Docker](#docker)
@@ -75,14 +83,6 @@ Features:
 - [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
 - [Debugging Axolotl](#debugging-axolotl)
 - [Need help? 🙋](#need-help-)
-- [Badge ❤🏷️](#badge-️)
-- [Community Showcase](#community-showcase)
-- [Contributing 🤝](#contributing-)
-- [Sponsors 🤝❤](#sponsors-)
-- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
-- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
-- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
-- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)

 </td>
 <td>
@@ -105,6 +105,148 @@ Features:
 </tr>
 </table>

+## Quickstart ⚡
+
+Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
+
+**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
+
+```bash
+pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
+
+# download examples and optionally deepspeed configs to the local path
+axolotl fetch examples
+axolotl fetch deepspeed_configs  # OPTIONAL
+
+# finetune using lora
+axolotl train examples/llama-3/lora-1b.yml
+```
+
+### Edge Builds 🏎️
+
+If you're looking for the latest features and updates between releases, you'll need to install
+from source.
+
+```bash
+git clone https://github.com/axolotl-ai-cloud/axolotl.git
+cd axolotl
+pip3 install packaging ninja
+pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
+```
+
+### Axolotl CLI Usage
+We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
+
+```bash
+# preprocess datasets - optional but recommended
+CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
+
+# finetune lora
+axolotl train examples/llama-3/lora-1b.yml
+
+# inference
+axolotl inference examples/llama-3/lora-1b.yml \
+    --lora-model-dir="./outputs/lora-out"
+
+# gradio
+axolotl inference examples/llama-3/lora-1b.yml \
+    --lora-model-dir="./outputs/lora-out" --gradio
+
+# remote yaml files - the yaml config can be hosted on a public URL
+# Note: the yaml config must directly link to the **raw** yaml
+axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
+```
+
+We've also added a new command for fetching `examples` and `deepspeed_configs` to your
+local machine. This will come in handy when installing `axolotl` from PyPI.
+
+```bash
+# Fetch example YAML files (stores in "examples/" folder)
+axolotl fetch examples
+
+# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
+axolotl fetch deepspeed_configs
+
+# Optionally, specify a destination folder
+axolotl fetch examples --dest path/to/folder
+```
+
+### Legacy Usage
+<details>
+
+<summary>Click to Expand</summary>
+
+While the Axolotl CLI is the preferred method for interacting with axolotl, we
+still support the legacy `-m axolotl.cli.*` usage.
+
+```bash
+# preprocess datasets - optional but recommended
+CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
+
+# finetune lora
+accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
+
+# inference
+accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
+    --lora_model_dir="./outputs/lora-out"
+
+# gradio
+accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
+    --lora_model_dir="./outputs/lora-out" --gradio
+
+# remote yaml files - the yaml config can be hosted on a public URL
+# Note: the yaml config must directly link to the **raw** yaml
+accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
+```
+
+</details>
+
+## Badge ❤🏷️
+
+Building something cool with Axolotl? Consider adding a badge to your model card.
+
+```markdown
+[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
+```
+
+[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
+
+## Sponsors 🤝❤
+
+If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
+
+---
+
+- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
+
+---
+
+## Contributing 🤝
+
+Please read the [contributing guide](./.github/CONTRIBUTING.md)
+
+Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
+
+PRs are **greatly welcome**!
+
+Please run the quickstart instructions followed by the below to setup env:
+```bash
+pip3 install -r requirements-dev.txt -r requirements-tests.txt
+pre-commit install
+
+# test
+pytest tests/
+
+# optional: run against all files
+pre-commit run --all-files
+```
+
+Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
+
+<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
+  <img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
+</a>
+
 ## Axolotl supports

 | | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
@@ -130,41 +272,6 @@ Features:
 ❌: not supported
 ❓: untested

-## Quickstart ⚡
-
-Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
-
-**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
-
-```bash
-git clone https://github.com/axolotl-ai-cloud/axolotl
-cd axolotl
-
-pip3 install packaging ninja
-pip3 install -e '.[flash-attn,deepspeed]'
-```
-
-### Usage
-```bash
-# preprocess datasets - optional but recommended
-CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
-
-# finetune lora
-accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
-
-# inference
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./outputs/lora-out"
-
-# gradio
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./outputs/lora-out" --gradio
-
-# remote yaml files - the yaml config can be hosted on a public URL
-# Note: the yaml config must directly link to the **raw** yaml
-accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
-```
-
 ## Advanced Setup

 ### Environment
@@ -213,7 +320,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
 3. Install Axolotl along with python dependencies
   ```bash
   pip3 install packaging
-  pip3 install -e '.[flash-attn,deepspeed]'
+  pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
   ```
 4. (Optional) Login to Huggingface to use gated models/datasets.
   ```bash
@@ -292,7 +399,7 @@ Please use WSL or Docker!

 Use the below instead of the install method in QuickStart.
 ```
-pip3 install -e '.'
+pip3 install --no-build-isolation -e '.'
 ```
 More info: [mac.md](/docs/mac.qmd)

@@ -682,86 +789,6 @@ See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, al

 ## Need help? 🙋

-Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
+Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.

-Need dedicated support? Please contact us at [✉️wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
+Need dedicated support? Please contact us at [✉️wing@axolotl.ai](mailto:wing@axolotl.ai) for dedicated support options.

-## Badge ❤🏷️
-
-Building something cool with Axolotl? Consider adding a badge to your model card.
-
-```markdown
-[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
-```
-
-[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
-
-## Community Showcase
-
-Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
-
-Open Access AI Collective
-- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b-fixed)
-- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
-- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
-
-PocketDoc Labs
-- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
-
-## Contributing 🤝
-
-Please read the [contributing guide](./.github/CONTRIBUTING.md)
-
-Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
-
-PRs are **greatly welcome**!
-
-Please run the quickstart instructions followed by the below to setup env:
-```bash
-pip3 install -r requirements-dev.txt -r requirements-tests.txt
-pre-commit install
-
-# test
-pytest tests/
-
-# optional: run against all files
-pre-commit run --all-files
-```
-
-Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
-
-<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
-  <img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
-</a>
-
-## Sponsors 🤝❤
-
-OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
-[NanoCode012](https://github.com/NanoCode012), [tmm1](https://github.com/tmm1),
-[mhenrichsen](https://github.com/mhenrichsen), [casper-hansen](https://github.com/casper-hansen),
-[hamelsmu](https://github.com/hamelsmu) and many more who help us accelerate forward by fixing bugs, answering
-community questions and implementing new features. Axolotl needs donations from sponsors for the compute needed to
-run our unit & integration tests, troubleshooting community issues, and providing bounties. If you love axolotl,
-consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsors/OpenAccess-AI-Collective),
-[Ko-fi](https://ko-fi.com/axolotl_ai) or reach out directly to
-[wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org).
-
----
-
-#### 💎 Diamond Sponsors - [Contact directly](mailto:wing@openaccessaicollective.org)
-
----
-
-#### 🥇 Gold Sponsors - $5000/mo
-
----
-
-#### 🥈 Silver Sponsors - $1000/mo
-
----
-
-#### 🥉 Bronze Sponsors - $500/mo
-
-- [JarvisLabs.ai](https://jarvislabs.ai)
-
----
````
A Jinja-templated CUDA Dockerfile (filename not preserved in this extract) drops the bitsandbytes CUDA override and disables build isolation for the editable install:

````diff
@@ -4,7 +4,6 @@ ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
 ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
 ENV CUDA="{{ CUDA }}"
-ENV BNB_CUDA_VERSION="{{ CUDA }}"
 ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
 ENV GITHUB_REF="{{ GITHUB_REF }}"
 ENV GITHUB_SHA="{{ GITHUB_SHA }}"
@@ -32,9 +31,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     fi

 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi

 RUN python scripts/unsloth_install.py | sh
````
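For background on the removed `BNB_CUDA_VERSION` line (an assumption about intent; the diff itself gives no rationale): bitsandbytes honors this variable to force loading a native library built for a specific CUDA version, so dropping it lets the library detect the runtime CUDA itself. A hedged sketch:

```python
# Sketch of the override being removed (assumed bitsandbytes behavior):
# with BNB_CUDA_VERSION unset, the library selects its binary from the
# CUDA runtime it detects instead of a pinned version.
import os

os.environ.pop("BNB_CUDA_VERSION", None)  # rely on auto-detection
import bitsandbytes  # noqa: E402  # picks a binary for the detected CUDA
```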
The CI test script splits its pytest invocations and asserts the expected PyTorch version up front:

````diff
@@ -1,6 +1,10 @@
 #!/bin/bash
 set -e

-pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
+python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
+
+pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
+# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
+pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/
+pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/integrations/
 pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
````
The main ARG-based Dockerfile gets the same two changes:

````diff
@@ -5,7 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
 ARG AXOLOTL_ARGS=""
 ARG CUDA="118"
-ENV BNB_CUDA_VERSION=$CUDA
 ARG PYTORCH_VERSION="2.1.2"

 ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,9 +20,9 @@ WORKDIR /workspace/axolotl

 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi

 RUN python scripts/unsloth_install.py | sh
````
The base image adds pkg-config to the build toolchain:

````diff
@@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

 RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
+    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
     && wget \
         https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
     && mkdir /root/.conda \
````
A cloud Dockerfile renames the deprecated Hugging Face cache variable:

````diff
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
 FROM axolotlai/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
-ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
````
The same rename lands in a second cloud Dockerfile:

````diff
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
 FROM axolotlai/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
-ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
````
The mamba-ssm variant Dockerfile follows suit:

````diff
@@ -5,7 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
 ARG AXOLOTL_ARGS=""
 ARG CUDA="118"
-ENV BNB_CUDA_VERSION=$CUDA
 ARG PYTORCH_VERSION="2.1.2"
 ARG GITHUB_REF="main"

@@ -25,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \

 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
````
In the AMD/ROCm installation doc, the flash-attention and axolotl installs also go through `--no-build-isolation`:

````diff
@@ -52,7 +52,7 @@ export GPU_ARCHS="gfx90a"
 cd flash-attention
 export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
 patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
-pip install .
+pip install --no-build-isolation .
 ```

 ### 6. Install Axolotl
@@ -63,7 +63,7 @@ Clone and install Axolotl:
 git clone https://github.com/axolotl-ai-cloud/axolotl
 cd axolotl
 pip install packaging ninja
-pip install -e .
+pip install --no-build-isolation -e .
 ```

 ### 7. Apply xformers Workaround
````
A Docker-focused doc updates both of its editable-install snippets:

````diff
@@ -71,7 +71,7 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us

 ```bash
 pip3 install packaging
-pip3 install -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```

 #### Remote Hosts
@@ -212,7 +212,7 @@ You will now be in the container. Next, perform an editable install of Axolotl:

 ```bash
 pip3 install packaging
-pip3 install -e '.[flash-attn,deepspeed]'
+pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```

 ### Attach To Container
````
The RLHF doc gains a KTO configuration example:

````diff
@@ -52,6 +52,26 @@ datasets:
     type: chat_template.argilla
 ```

+#### KTO
+
+```yaml
+rl: kto
+rl_beta: 0.5
+kto_desirable_weight: 0.2
+
+remove_unused_columns: false
+
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
+    type: llama3.ultra
+    split: train
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+```
+
 #### Using local dataset files
 ```yaml
 datasets:
````
A notebook cell updates its install command:

````diff
@@ -24,7 +24,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install axolotl[deepspeed]"
+    "!pip install --no-build-isolation axolotl[deepspeed]"
   ]
  },
  {
````
**examples/llama-3/lora-1b.yml** (new file, 74 lines)

```yaml
base_model: NousResearch/Llama-3.2-1B

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
```
**examples/llama-3/qlora-1b-kto.yaml** (new file, 75 lines)

```yaml
base_model: meta-llama/Llama-3.2-1B

load_in_8bit: false
load_in_4bit: true
strict: false

rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/qlora-out

remove_unused_columns: false

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: false # not supported with kto
eval_sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 20
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
```
An existing Llama 3.2 example switches to the NousResearch mirror and drops `lora_target_linear`:

````diff
@@ -1,4 +1,4 @@
-base_model: meta-llama/Llama-3.2-1B
+base_model: NousResearch/Llama-3.2-1B

 load_in_8bit: false
 load_in_4bit: true
@@ -22,7 +22,6 @@ pad_to_sequence_len: true
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
-lora_target_linear: true
 lora_fan_in_fan_out:
 lora_target_modules:
   - gate_proj
````
Three vision-model example configs are deleted outright. The first is the llava-1.5 config:

````diff
@@ -1,63 +0,0 @@
-base_model: llava-hf/llava-1.5-7b-hf
-processor_type: AutoProcessor
-strict: false
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-chat_template: llava
-datasets:
-  - path: HuggingFaceH4/llava-instruct-mix-vsft
-    type: chat_template
-    split: train[:1%]
-    field_messages: messages
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.0
-output_dir: ./outputs/out
-
-adapter: lora
-lora_model_dir:
-
-sequence_len: 8192
-pad_to_sequence_len: false
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16:
-tf32: true
-
-gradient_checkpointing: true
-local_rank:
-logging_steps: 1
-flash_attention: true
-eager_attention:
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
````
The second deletion is the pixtral-12b config:

````diff
@@ -1,65 +0,0 @@
-base_model: mistral-community/pixtral-12b
-processor_type: AutoProcessor
-strict: false
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-chat_template: pixtral
-datasets:
-  - path: HuggingFaceH4/llava-instruct-mix-vsft
-    type: chat_template
-    split: train[:1%]
-    field_messages: messages
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.0
-output_dir: ./outputs/out
-
-adapter: lora
-lora_model_dir:
-
-sequence_len: 8192
-pad_to_sequence_len: false
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16:
-tf32: true
-
-gradient_checkpointing: true
-local_rank:
-logging_steps: 1
-flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
-eager_attention:
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
-  pad_token: <|end_of_text|>
````
And the third is the Qwen2-VL config:

````diff
@@ -1,63 +0,0 @@
-base_model: Qwen/Qwen2-VL-7B-Instruct
-processor_type: AutoProcessor
-strict: false
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-chat_template: qwen2_vl
-datasets:
-  - path: HuggingFaceH4/llava-instruct-mix-vsft
-    type: chat_template
-    split: train[:1%]
-    field_messages: messages
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.0
-output_dir: ./outputs/out
-
-adapter: lora
-lora_model_dir:
-
-sequence_len: 8192
-pad_to_sequence_len: false
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16:
-tf32: true
-
-gradient_checkpointing: true
-local_rank:
-logging_steps: 1
-flash_attention: true
-eager_attention:
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
````
**pyproject.toml** (new file, 26 lines)

```toml
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"

[project]
name = "axolotl"
dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"

[project.scripts]
axolotl = "axolotl.cli.main:main"

[project.urls]
Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"

[tool.setuptools_scm]

[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true

[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
```
The main requirements file regroups the Darwin-incompatible binary packages behind explicit markers and bumps several pins:

````diff
@@ -1,22 +1,30 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

+# START section of dependencies that don't install on Darwin/MacOS
+bitsandbytes==0.45.0
+triton>=2.3.0
+mamba-ssm==1.2.0.post1
+flash-attn==2.7.0.post2
+xformers>=0.0.23.post1
+autoawq==0.2.7.post3
+liger-kernel==0.4.2
+# END section
+
 packaging==23.2
-peft==0.13.2
+peft==0.14.0
-transformers==4.46.3
+transformers>=4.46.3
 tokenizers>=0.20.1
-bitsandbytes==0.44.1
-accelerate==1.1.0
+accelerate==1.2.0
 datasets==3.1.0
-deepspeed==0.15.4
+deepspeed==0.16.1
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-flash-attn==2.7.0.post2
 sentencepiece
 wandb
 einops
-xformers>=0.0.23.post1
 optimum==1.16.2
 hf_transfer
 colorama
@@ -31,18 +39,13 @@ art
 gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
-autoawq==0.2.7.post2
-triton>=2.3.0
-liger-kernel==0.4.2
-
-mamba-ssm==1.2.0.post1

 # remote filesystems
 s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs

-trl==0.12.0
+trl==0.12.1
 zstandard==0.22.0
 fastcore
````
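The Darwin-incompatible binaries now sit between explicit `# START`/`# END` markers so they can be skipped when resolving dependencies on macOS. How the markers are consumed is not shown in this comparison (it lives in `setup.py`'s `parse_requirements`), so the following filter is a hypothetical sketch of the idea:

```python
# Hypothetical marker-aware requirements filter; the real logic in
# setup.py's parse_requirements is not shown in this comparison.
import platform

def filter_requirements(lines):
    skipping = False
    for raw in lines:
        line = raw.strip()
        if line.startswith("# START section") and platform.system() == "Darwin":
            skipping = True  # skip deps that don't build on macOS
        elif line.startswith("# END section"):
            skipping = False
        elif not skipping and line and not line.startswith(("#", "--")):
            yield line

with open("requirements.txt", encoding="utf-8") as fin:
    reqs = list(filter_requirements(fin))
```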
The cut-cross-entropy install script restructures its probe so the submodule check only runs when the package is present:

````diff
@@ -16,11 +16,11 @@ if v < V("2.4.0"):
     sys.exit(0)

 cce_spec = importlib.util.find_spec("cut_cross_entropy")
-cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")

 UNINSTALL_PREFIX = ""
-if cce_spec and not cce_spec_transformers:
-    UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
+if cce_spec:
+    if not importlib.util.find_spec("cut_cross_entropy.transformers"):
+        UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "

 print(
     UNINSTALL_PREFIX
````
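The nesting matters because `importlib.util.find_spec` on a dotted name imports the parent package first: probing `cut_cross_entropy.transformers` while `cut_cross_entropy` itself is absent raises `ModuleNotFoundError` instead of returning `None`. A small demonstration of the safe ordering:

```python
import importlib.util

# Probing the top-level package returns a spec or None without raising.
parent = importlib.util.find_spec("cut_cross_entropy")
if parent:
    # Only now is it safe to probe the submodule; doing this when the
    # parent is missing raises ModuleNotFoundError rather than returning
    # None, which is why the check was nested.
    sub = importlib.util.find_spec("cut_cross_entropy.transformers")
```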
A doc snippet updates its from-source install:

````diff
@@ -13,5 +13,5 @@ cd /workspace
 rm -rf /workspace/axolotl
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-pip install --no-deps -e .
+pip install --no-build-isolation --no-deps -e .
 ```
````
**setup.py** (31 changed lines): static metadata moves to pyproject.toml, the version is read from the package, and the console script is registered:

````diff
@@ -1,8 +1,10 @@
 """setup.py for axolotl"""
+import ast
+import os
 import platform
 import re
 from importlib.metadata import PackageNotFoundError, version
+from pathlib import Path

 from setuptools import find_packages, setup

@@ -91,24 +93,39 @@ def parse_requirements():
     return _install_requires, _dependency_links


+def get_package_version():
+    with open(
+        Path(os.path.dirname(os.path.abspath(__file__)))
+        / "src"
+        / "axolotl"
+        / "__init__.py",
+        "r",
+        encoding="utf-8",
+    ) as fin:
+        version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
+    version_ = ast.literal_eval(version_match.group(1))
+    return version_
+
+
 install_requires, dependency_links = parse_requirements()


 setup(
-    name="axolotl",
-    version="0.5.2",
-    description="LLM Trainer",
-    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
+    version=get_package_version(),
     package_dir={"": "src"},
     packages=find_packages("src"),
     install_requires=install_requires,
     dependency_links=dependency_links,
+    entry_points={
+        "console_scripts": [
+            "axolotl=axolotl.cli.main:main",
+        ],
+    },
     extras_require={
         "flash-attn": [
             "flash-attn==2.7.0.post2",
         ],
         "deepspeed": [
-            "deepspeed==0.15.4",
+            "deepspeed==0.16.1",
             "deepspeed-kernels",
         ],
         "mamba-ssm": [
````
The new `src/axolotl/__init__.py` referenced by `get_package_version()` holds the version:

````diff
@@ -0,0 +1,3 @@
+"""Axolotl - Train and fine-tune large language models"""
+
+__version__ = "0.6.0"
````
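This module is now the single source of truth that `get_package_version()` in setup.py reads. A worked example of that extraction against the file added above:

```python
import ast
import re

# Contents of the new src/axolotl/__init__.py shown above.
init_text = (
    '"""Axolotl - Train and fine-tune large language models"""\n'
    "\n"
    '__version__ = "0.6.0"\n'
)

match = re.search(r"^__version__\s*=\s*(.*)$", init_text, re.MULTILINE)
version = ast.literal_eval(match.group(1))  # evaluates the quoted literal
assert version == "0.6.0"
```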
In the CLI module, `choose_config` now returns plain strings and `load_cfg` prepares plugins before validation:

````diff
@@ -380,7 +380,7 @@ def choose_config(path: Path):

     if len(yaml_files) == 1:
         print(f"Using default YAML file '{yaml_files[0]}'")
-        return yaml_files[0]
+        return str(yaml_files[0])

     print("Choose a YAML file:")
     for idx, file in enumerate(yaml_files):
@@ -391,7 +391,7 @@ def choose_config(path: Path):
         try:
             choice = int(input("Enter the number of your choice: "))
             if 1 <= choice <= len(yaml_files):
-                chosen_file = yaml_files[choice - 1]
+                chosen_file = str(yaml_files[choice - 1])
             else:
                 print("Invalid choice. Please choose a number from the list.")
         except ValueError:
@@ -432,6 +432,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
     except:  # pylint: disable=bare-except  # noqa: E722
         gpu_version = None

+    prepare_plugins(cfg)
+
     cfg = validate_config(
         cfg,
         capabilities={
@@ -440,12 +442,10 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
             "compute_capability": gpu_version,
         },
         env_capabilities={
-            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
+            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
         },
     )

-    prepare_plugins(cfg)
-
     prepare_optim_env(cfg)

     prepare_opinionated_env(cfg)
````
|||||||
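For context on the torch_version capability above, a quick sketch of what the split produces (the version string here is illustrative): local CUDA builds report versions like "2.5.1+cu124", and the capability check only wants the numeric part before the "+".

v = "2.5.1+cu124"
assert v.split("+", maxsplit=1)[0] == "2.5.1"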
@@ -2,6 +2,7 @@
 CLI to run inference on a trained model
 """
 from pathlib import Path
+from typing import Union

 import fire
 import transformers
@@ -16,7 +17,7 @@ from axolotl.cli import (
 from axolotl.common.cli import TrainerCliArgs


-def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
+def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, inference=True, **kwargs)
233 src/axolotl/cli/main.py Normal file
@@ -0,0 +1,233 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess  # nosec B404
from typing import Optional

import click

import axolotl
from axolotl.cli.utils import (
    add_options_from_config,
    add_options_from_dataclass,
    build_command,
    fetch_from_github,
)
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
    """Axolotl CLI - Train and fine-tune large language models"""


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
    """Preprocess datasets before training."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    from axolotl.cli.preprocess import do_cli

    do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
    """Train or fine-tune a model."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if accelerate:
        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
        subprocess.run(cmd, check=True)  # nosec B603
    else:
        from axolotl.cli.train import do_cli

        do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for multi-GPU inference",
)
@click.option(
    "--lora-model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing LoRA model",
)
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=str),
    help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
    config: str,
    accelerate: bool,
    lora_model_dir: Optional[str] = None,
    base_model: Optional[str] = None,
    **kwargs,
):
    """Run inference with a trained model."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    del kwargs["inference"]  # interferes with inference.do_cli

    if lora_model_dir:
        kwargs["lora_model_dir"] = lora_model_dir
    if base_model:
        kwargs["output_dir"] = base_model

    if accelerate:
        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
        subprocess.run(cmd, check=True)  # nosec B603
    else:
        from axolotl.cli.inference import do_cli

        do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=False,
    help="Use accelerate launch for multi-GPU operations",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing model weights to shard",
)
@click.option(
    "--save-dir",
    type=click.Path(path_type=str),
    help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
    """Shard model weights."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if accelerate:
        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
        subprocess.run(cmd, check=True)  # nosec B603
    else:
        from axolotl.cli.shard import do_cli

        do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for weight merging",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing sharded weights",
)
@click.option(
    "--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
    """Merge sharded FSDP model weights."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    if accelerate:
        base_cmd = [
            "accelerate",
            "launch",
            "-m",
            "axolotl.cli.merge_sharded_fsdp_weights",
        ]
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
        subprocess.run(cmd, check=True)  # nosec B603
    else:
        from axolotl.cli.merge_sharded_fsdp_weights import do_cli

        do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--lora-model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing the LoRA model to merge",
)
@click.option(
    "--output-dir",
    type=click.Path(path_type=str),
    help="Directory to save the merged model",
)
def merge_lora(
    config: str,
    lora_model_dir: Optional[str] = None,
    output_dir: Optional[str] = None,
):
    """Merge a trained LoRA into a base model"""
    kwargs = {}
    if lora_model_dir:
        kwargs["lora_model_dir"] = lora_model_dir
    if output_dir:
        kwargs["output_dir"] = output_dir

    from axolotl.cli.merge_lora import do_cli

    do_cli(config=config, **kwargs)


@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
    """
    Fetch example configs or other resources.

    Available directories:
    - examples: Example configuration files
    - deepspeed_configs: DeepSpeed configuration files
    """
    fetch_from_github(f"{directory}/", dest)


def main():
    cli()


if __name__ == "__main__":
    main()
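A hedged sketch of driving the new command group without a shell, via click's test runner (this assumes the package above is importable; the printed version is whatever is installed):

from click.testing import CliRunner

from axolotl.cli.main import cli

runner = CliRunner()
result = runner.invoke(cli, ["--version"])
print(result.output)  # e.g. "axolotl, version 0.6.0"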
@@ -2,6 +2,7 @@
 CLI to run merge a trained LoRA into a base model
 """
 from pathlib import Path
+from typing import Union

 import fire
 import transformers
@@ -11,7 +12,7 @@ from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
 from axolotl.common.cli import TrainerCliArgs


-def do_cli(config: Path = Path("examples/"), **kwargs):
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
@@ -177,7 +177,7 @@ def merge_fsdp_weights(
     state.wait_for_everyone()


-def do_cli(config: Path = Path("examples/"), **kwargs):
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
218 src/axolotl/cli/utils.py Normal file
@@ -0,0 +1,218 @@
"""Utility methods for axoltl CLI."""
|
||||||
|
import concurrent.futures
|
||||||
|
import dataclasses
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from types import NoneType
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
|
||||||
|
|
||||||
|
import click
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.utils")
|
||||||
|
|
||||||
|
|
||||||
|
def add_options_from_dataclass(config_class: Type[Any]):
|
||||||
|
"""Create Click options from the fields of a dataclass."""
|
||||||
|
|
||||||
|
def decorator(function):
|
||||||
|
# Process dataclass fields in reverse order for correct option ordering
|
||||||
|
for field in reversed(dataclasses.fields(config_class)):
|
||||||
|
field_type = field.type
|
||||||
|
|
||||||
|
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||||
|
field_type = next(
|
||||||
|
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||||
|
)
|
||||||
|
|
||||||
|
if field_type == bool:
|
||||||
|
field_name = field.name.replace("_", "-")
|
||||||
|
option_name = f"--{field_name}/--no-{field_name}"
|
||||||
|
function = click.option(
|
||||||
|
option_name,
|
||||||
|
default=field.default,
|
||||||
|
help=field.metadata.get("description"),
|
||||||
|
)(function)
|
||||||
|
else:
|
||||||
|
option_name = f"--{field.name.replace('_', '-')}"
|
||||||
|
function = click.option(
|
||||||
|
option_name,
|
||||||
|
type=field_type,
|
||||||
|
default=field.default,
|
||||||
|
help=field.metadata.get("description"),
|
||||||
|
)(function)
|
||||||
|
return function
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def add_options_from_config(config_class: Type[BaseModel]):
|
||||||
|
"""Create Click options from the fields of a Pydantic model."""
|
||||||
|
|
||||||
|
def decorator(function):
|
||||||
|
# Process model fields in reverse order for correct option ordering
|
||||||
|
for name, field in reversed(config_class.model_fields.items()):
|
||||||
|
if field.annotation == bool:
|
||||||
|
field_name = name.replace("_", "-")
|
||||||
|
option_name = f"--{field_name}/--no-{field_name}"
|
||||||
|
function = click.option(
|
||||||
|
option_name, default=None, help=field.description
|
||||||
|
)(function)
|
||||||
|
else:
|
||||||
|
option_name = f"--{name.replace('_', '-')}"
|
||||||
|
function = click.option(
|
||||||
|
option_name, default=None, help=field.description
|
||||||
|
)(function)
|
||||||
|
return function
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
||||||
|
"""Build command list from base command and options."""
|
||||||
|
cmd = base_cmd.copy()
|
||||||
|
|
||||||
|
for key, value in options.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = key.replace("_", "-")
|
||||||
|
|
||||||
|
if isinstance(value, bool):
|
||||||
|
if value:
|
||||||
|
cmd.append(f"--{key}")
|
||||||
|
else:
|
||||||
|
cmd.extend([f"--{key}", str(value)])
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
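To illustrate build_command's option expansion, a small sketch (the option names and values here are made up for the example): None values are dropped, booleans become bare flags, everything else is stringified, and underscores turn into dashes.

cmd = build_command(
    ["accelerate", "launch", "-m", "axolotl.cli.train"],
    {"config": None, "flash_attention": True, "micro_batch_size": 2},
)
# -> ["accelerate", "launch", "-m", "axolotl.cli.train",
#     "--flash-attention", "--micro-batch-size", "2"]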
def download_file(
    file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
    """
    Download a single file and return its processing status.

    Args:
        file_info: Tuple of (file_path, remote_sha)
        raw_base_url: Base URL for raw GitHub content
        dest_path: Local destination directory
        dir_prefix: Directory prefix to filter files

    Returns:
        Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
    """
    file_path, remote_sha = file_info
    raw_url = f"{raw_base_url}/{file_path}"
    dest_file = dest_path / file_path.split(dir_prefix)[-1]

    # Check if file exists and needs updating
    if dest_file.exists():
        with open(dest_file, "rb") as file:
            content = file.read()
        # Calculate git blob SHA
        blob = b"blob " + str(len(content)).encode() + b"\0" + content
        local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

        if local_sha == remote_sha:
            print(f"Skipping {file_path} (unchanged)")
            return file_path, "unchanged"

        print(f"Updating {file_path}")
        status = "updated"
    else:
        print(f"Downloading {file_path}")
        status = "new"

    # Create directories if needed
    dest_file.parent.mkdir(parents=True, exist_ok=True)

    # Download and save file
    try:
        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()

        with open(dest_file, "wb") as file:
            file.write(response.content)

        return file_path, status
    except (requests.RequestException, IOError) as request_error:
        print(f"Error downloading {file_path}: {str(request_error)}")
        return file_path, "error"
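The local-vs-remote comparison above works by recomputing git's blob SHA rather than hashing the raw bytes; a standalone sketch of that header-prefixed hash:

import hashlib

content = b"hello\n"
# git prepends a "blob <size>\0" header before hashing, so a plain
# sha1(content) would never match the SHA reported by the trees API.
blob = b"blob " + str(len(content)).encode() + b"\0" + content
print(hashlib.sha1(blob, usedforsecurity=False).hexdigest())
# ce013625030ba8dba906f756967f9e9ca394464a, same as `git hash-object` on the file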
def fetch_from_github(
    dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
) -> None:
    """
    Sync files from a specific directory in the GitHub repository.
    Only downloads files that don't exist locally or have changed.

    Args:
        dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
        dest_dir: Local destination directory
        max_workers: Maximum number of concurrent downloads
    """
    api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
    raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

    # Get repository tree with timeout
    response = requests.get(api_url, timeout=30)
    response.raise_for_status()
    tree = json.loads(response.text)

    # Filter for files and get their SHA
    files = {
        item["path"]: item["sha"]
        for item in tree["tree"]
        if item["type"] == "blob" and item["path"].startswith(dir_prefix)
    }

    if not files:
        raise click.ClickException(f"No files found in {dir_prefix}")

    # Default destination directory is the last part of dir_prefix
    default_dest = Path(dir_prefix.rstrip("/"))
    dest_path = Path(dest_dir) if dest_dir else default_dest

    # Keep track of processed files for summary
    files_processed: Dict[str, List[str]] = {
        "new": [],
        "updated": [],
        "unchanged": [],
        "error": [],
    }

    # Process files in parallel using ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {
            executor.submit(
                download_file,
                (file_path, remote_sha),
                raw_base_url,
                dest_path,
                dir_prefix,
            ): file_path
            for file_path, remote_sha in files.items()
        }

        # Process completed tasks as they finish
        for future in concurrent.futures.as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                file_path, status = future.result()
                files_processed[status].append(file_path)
            except (requests.RequestException, IOError) as request_error:
                print(f"Error processing {file_path}: {str(request_error)}")
                files_processed["error"].append(file_path)

    # Log summary
    LOG.info("\nSync Summary:")
    LOG.info(f"New files: {len(files_processed['new'])}")
    LOG.info(f"Updated files: {len(files_processed['updated'])}")
    LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
    if files_processed["error"]:
        LOG.info(f"Failed files: {len(files_processed['error'])}")
@@ -3,36 +3,88 @@ helper functions for fixing the embeddings/tokenizer
 """

 # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+# GNU LESSER GENERAL PUBLIC LICENSE
+#                       Version 3, 29 June 2007
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+# Everyone is permitted to copy and distribute verbatim copies
+# of this license document, but changing it is not allowed.

 import gc
 import itertools
+import logging
+from collections import Counter

+import datasets
 import numpy as np
 import torch

+LOG = logging.getLogger("axolotl.core.tokenizer_utils")

-@torch.inference_mode
-def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
+
+@torch.inference_mode()
+def fix_untrained_tokens(  # pylint: disable=too-many-return-statements
+    model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
+):
     """
-    Many of the newer models have reserved tokens that are not trained.
+    Llama-3 for eg has untrained vectors in the base model.
+    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
+    We reset them to the mean of the rest of the tokens
     """
+    # Code licensed under LGPL
     embedding_matrix = model.get_input_embeddings().weight
     lm_head_matrix = model.get_output_embeddings().weight
+    chat_template = getattr(tokenizer, "chat_template", None)
+    tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
+
+    # Ignore some model checks for now
+    if not ignored_tokenizer_names:
+        ignored_tokenizer_names = []
+    if (
+        model.config._name_or_path  # pylint: disable=protected-access
+        in ignored_tokenizer_names
+    ):
+        return
+
+    # Sometimes the sizes can be different like in vision models
+    # Ie <image> is in input, but not in output
+    min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
+    embedding_matrix = embedding_matrix[:, :min_size]
+    lm_head_matrix = lm_head_matrix[:, :min_size]

     # Get untrained tokens
-    indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
+    indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
+    # Check lm_head as well
+
+    # Does NOT work for Llama 3.1!!
+    indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
+
+    # We instead check for repeated vectors
+    lm_head_where = torch.where(indicator_untrained1)[0]
+    lm_head_bad = lm_head_matrix[lm_head_where]
+    lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
+    counter = Counter()
+    for row in lm_head_bad:
+        counter[hash(row.data.tobytes())] += 1
+    counter = Counter({k: c for k, c in counter.items() if c >= 2})
+
+    lm_head_where = lm_head_where.cpu().numpy()
+    final_bad_lm_head = []
+    for j, row in enumerate(lm_head_bad):
+        if hash(row.data.tobytes()) in counter:
+            final_bad_lm_head.append(lm_head_where[j])
+    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
+    indicator_untrained2[final_bad_lm_head] = True
+
+    # Combine both checks
+    indicator_untrained = indicator_untrained1 & indicator_untrained2
+
+    # Remove pad token possibility
+    if hasattr(tokenizer, "pad_token_id"):
+        pad_token_id = tokenizer.pad_token_id
+        if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
+            indicator_untrained[pad_token_id] = False

     where_untrained = torch.where(indicator_untrained)[0]
     n_untrained = where_untrained.shape[0]
     n_trained = embedding_matrix.shape[0] - n_untrained
@@ -40,10 +92,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
     # Get set and actual tokens
     where_untrained = where_untrained.tolist()
     if len(where_untrained) == 0:
-        return False
+        return

     # Remove untrained indices where it's longer
-
     where_untrained_set = frozenset(where_untrained)
     actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
     # Remove None items in actual_bad_tokens
@@ -53,10 +104,14 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
     if_bad_first = False
     if_bad_second = False
     # Check tokenizer's chat template for any untrained tokens
-    chat_template = getattr(tokenizer, "chat_template", None)
     if chat_template is not None:
         if_bad_first = any(x in chat_template for x in actual_bad_tokens)

+    if isinstance(train_dataset, datasets.IterableDataset):
+        # Skip the check, since the code below assumes
+        # an indexable dataset
+        return
+
     # Check the first 250, last 250 input_ids
     size_dataset = len(train_dataset)
     size = min(size_dataset, 250)
@@ -83,7 +138,69 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):

     # Check if bad tokens exists!
     if not if_bad_first and not if_bad_second:
-        return False
+        return
+
+    # Check if lm_head / embed_token are trainable!
+    bad_not_trainable = False
+    if not embedding_matrix.requires_grad:
+        bad_not_trainable = True
+    if not lm_head_matrix.requires_grad:
+        bad_not_trainable = True
+
+    if bad_not_trainable:  # pylint: disable=too-many-nested-blocks
+        final_bad_items = []
+
+        # Re-check the first 250, last 250 input_ids
+        size_dataset = len(train_dataset)
+        size = min(size_dataset, 250)
+        for j in range(size):
+            input_ids = train_dataset[j]
+            if "input_ids" in input_ids:
+                input_ids = input_ids["input_ids"]
+                for item in input_ids:
+                    if item in where_untrained_set:
+                        final_bad_items.append(item)
+
+        # Re-check last 250
+        left = max(size_dataset - 250, 0)
+        for j in range(left, size_dataset):
+            input_ids = train_dataset[j]
+            if "input_ids" in input_ids:
+                input_ids = input_ids["input_ids"]
+                for item in input_ids:
+                    if item in where_untrained_set:
+                        final_bad_items.append(item)
+
+        # If no bad tokens, possibly chat template itself has issues?
+        if len(final_bad_items) == 0:
+            # Recheck 2000 and last 2000 items
+            size_dataset = len(train_dataset)
+            size = min(size_dataset, 2000)
+            for j in range(size):
+                input_ids = train_dataset[j]
+                if "input_ids" in input_ids:
+                    input_ids = input_ids["input_ids"]
+                    for item in input_ids:
+                        if item in where_untrained_set:
+                            final_bad_items.append(item)
+
+            # Re-check last 2000
+            left = max(size_dataset - 2000, 0)
+            for j in range(left, size_dataset):
+                input_ids = train_dataset[j]
+                if "input_ids" in input_ids:
+                    input_ids = input_ids["input_ids"]
+                    for item in input_ids:
+                        if item in where_untrained_set:
+                            final_bad_items.append(item)
+
+            # Most likely false signal!
+            if len(final_bad_items) == 0:
+                return
+
+        raise ValueError(
+            f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
+        )

     # Count all the possible bad tokens
     final_counts = np.zeros(
@@ -97,6 +214,23 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):

     train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")

+    # Get counts for untrained tokens
+    counts_untrained = final_counts[where_untrained]
+    # Identify untrained tokens seen in train_dataset
+    indices_seen_in_train = np.where(counts_untrained > 0)[0]
+    tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
+
+    if len(tokens_to_update) == 0:
+        LOG.info(
+            "No untrained tokens found in train_dataset. No embeddings were modified."
+        )
+        return
+
+    # Log the token IDs that are being rescaled
+    LOG.info(
+        f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
+    )

     # Get sum of all items
     sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
     sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
@@ -113,38 +247,26 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
     mean_embedding = sum_embedding / n_trained
     mean_lm_head = sum_lm_head / n_trained

-    # Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
-    scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
+    # Compute scaling for tokens to update
+    scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
     scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
-    mean_embedding = (
-        mean_embedding.repeat(
-            (
-                n_untrained,
-                1,
-            )
-        )
-        * scaling
-    )
-    mean_lm_head = (
-        mean_lm_head.repeat(
-            (
-                n_untrained,
-                1,
-            )
-        )
-        * scaling
-    )
-    where_null = scaling.ravel() == 0
-    mean_embedding[where_null] = 0
-    mean_lm_head[where_null] = 0

-    # Set them to the mean
-    embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
-    lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
+    # Prepare mean embeddings for tokens to update
+    mean_embedding_repeated = (
+        mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
+    )
+    mean_lm_head_repeated = (
+        mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
+    )
+
+    # Update embeddings only for tokens seen in train_dataset
+    embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
+        embedding_matrix.dtype
+    )
+    lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)

     # Clean up
     for _ in range(3):
         gc.collect()
         torch.cuda.empty_cache()
-    return True
+    return
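The mean-and-scale update above reduces to a simple per-row rule: each untrained row becomes the mean of the trained rows, scaled by how often that token appears relative to the most frequent one. A toy sketch with made-up numbers:

import torch

emb = torch.tensor([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]])  # row 2 is "untrained"
counts = torch.tensor([0.0, 0.0, 5.0])  # row 2 seen 5 times in the dataset
mean_row = emb[:2].mean(dim=0)  # mean of the trained rows -> [2., 3.]
scale = counts[2] / counts.max()  # frequency relative to the most common token
emb[2] = mean_row * scale
print(emb[2])  # tensor([2., 3.])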
@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
 import torch
 import transformers
 from datasets import Dataset
+from packaging import version
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -957,13 +958,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):

         return res

-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
         """
         Log `logs` on the various objects watching training, including stored metrics.

         Args:
             logs (`Dict[str, float]`):
                 The values to log.
+            start_time (`Optional[float]`):
+                The start of training.
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
@@ -971,7 +974,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
-        return super().log(logs)
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            try:
+                return super().log(logs, start_time)
+            except TypeError:
+                return super().log(logs)  # transformers<=4.46
+        return super().log(logs)  # transformers<=4.46

     def store_metrics(
         self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
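All of the log overrides in this file use the same signature gate; the extra try/except in AxolotlTrainer.log additionally tolerates builds whose version string passes the check but whose Trainer.log still rejects start_time. A standalone sketch of the check (the installed version is whatever your environment has):

import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
    print("Trainer.log takes (logs, start_time)")
else:
    print("Trainer.log takes (logs) only")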
@@ -987,6 +996,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)

+    def _evaluate(self, *args, **kwargs):
+        metrics = super()._evaluate(*args, **kwargs)
+
+        # cleanup memory after evals
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return metrics
+

 class AxolotlMambaTrainer(AxolotlTrainer):
     """
@@ -1155,6 +1173,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
             torch.cuda.empty_cache()
         return loss

+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(DPOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(DPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
+

 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """
@@ -1163,6 +1197,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):

     tag_names = ["axolotl", "orpo"]

+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(ORPOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(ORPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
+

 class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
     """
@@ -1171,6 +1221,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):

     tag_names = ["axolotl", "kto"]

+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # train metrics should have no prefix, eval should have 'eval_'
+        prefix = "eval_" if train_eval == "eval" else ""
+        # accumulate average metrics from sums and lengths
+        for split in ["chosen", "rejected"]:
+            if f"count/{split}" in self._stored_metrics[train_eval]:
+                count_sum = (
+                    torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
+                    .sum()
+                    .item()
+                )
+                for metric in ["rewards", "logps", "logits"]:
+                    logs[f"{prefix}{metric}/{split}"] = (
+                        torch.Tensor(
+                            self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
+                        )
+                        .sum()
+                        .item()
+                        / count_sum
+                    )
+                    # delete obsolete metric
+                    del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
+                del self._stored_metrics[train_eval][f"count/{split}"]
+        # calculate reward margin
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
+            logs[f"{prefix}rewards/margins"] = (
+                logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+            )
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(KTOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(KTOTrainer, self).log(logs)  # pylint: disable=bad-super-call
+

 class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
     """
@@ -1179,6 +1272,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):

     tag_names = ["axolotl", "cpo"]

+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(CPOTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(CPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
+

 class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
     """
@@ -1187,6 +1296,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):

     tag_names = ["axolotl", "reward"]

+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            return super(RewardTrainer, self).log(  # pylint: disable=bad-super-call
+                logs, start_time
+            )
+        # transformers<=4.46
+        return super(RewardTrainer, self).log(logs)  # pylint: disable=bad-super-call
+

 class TrainerBuilderBase(abc.ABC):
     """
@@ -1259,8 +1377,6 @@ class TrainerBuilderBase(abc.ABC):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
         if self.cfg.use_mlflow and is_mlflow_available():
-            from transformers.integrations.integration_utils import MLflowCallback
-
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
             )
@@ -1268,7 +1384,6 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.extend(
                 [
                     SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
-                    MLflowCallback,
                 ]
             )
         if self.cfg.use_comet and is_comet_available():
@@ -1897,7 +2012,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 collator = MultiModalChatDataCollator
                 kwargs["processor"] = self.processor
                 kwargs["chat_template"] = training_args.chat_template
-                kwargs["chat_template_type"] = self.cfg.chat_template
             else:
                 collator = DataCollatorForSeq2Seq
@@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def check_dtype_is_half(cls, data):
-        if not (data.get("bf16") or data.get("fp16")):
+        if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
             raise ValueError(
                 "Cut Cross Entropy requires fp16/bf16 training for backward pass. "
                 "Please set `bf16` or `fp16` to `True`."
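The guard now only fires when cut_cross_entropy is actually enabled. A hedged, self-contained toy model showing the before/after behavior (this is not the real AxolotlInputConfig, just an illustration):

from pydantic import BaseModel, model_validator

class ToyConfig(BaseModel):
    cut_cross_entropy: bool = False
    bf16: bool = False

    @model_validator(mode="before")
    @classmethod
    def check_dtype_is_half(cls, data):
        # the old check raised whenever bf16/fp16 was unset, even with the
        # feature disabled; the fix keys the check on cut_cross_entropy
        if data.get("cut_cross_entropy") and not data.get("bf16"):
            raise ValueError("Cut Cross Entropy requires bf16/fp16")
        return data

ToyConfig()  # fine now; the old check would have raised here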
0 src/axolotl/monkeypatch/models/__init__.py Normal file

170 src/axolotl/monkeypatch/models/llama/modeling_llama.py Normal file
@@ -0,0 +1,170 @@
import contextlib
import inspect
import types

from torchtune.training import OffloadActivations
from transformers import LlamaConfig, LlamaForCausalLM

from axolotl.monkeypatch.unsloth_ import detab_code

HF_MODEL_OUTPUTS = """
outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
    **kwargs,
)
""".lstrip()

PATCHED_HF_MODEL_OUTPUTS = """
with self.act_offloading_ctx_manager:
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )
""".lstrip()

LCE_MODEL_OUTPUTS = """
outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
)
""".lstrip()

PATCHED_LCE_OUTPUTS = """
with self.act_offloading_ctx_manager:
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
""".lstrip()

HF_GA_FORWARD_1 = """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
""".lstrip()

PATCHED_HF_GA_FORWARD_1 = """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
""".lstrip()

HF_GA_FORWARD_2 = """
loss = None
if labels is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
""".lstrip()

PATCHED_HF_GA_FORWARD_2 = """
loss = None
if labels is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
""".lstrip()


class AxolotlLlamaForCausalLM(LlamaForCausalLM):
    act_offloading_ctx_manager = contextlib.nullcontext()

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

    @classmethod
    def set_forward(cls):
        forward_source = inspect.getsource(LlamaForCausalLM.forward)
        forward_source, _ = detab_code(forward_source)
        cls.forward = types.MethodType(
            compile(forward_source, "<forward>", "exec"), cls
        )

    @classmethod
    def enable_act_offloading(cls):
        forward_source = inspect.getsource(cls.forward)
        forward_source = forward_source.replace(
            HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
        )
        forward_source, _ = detab_code(forward_source)
        # replace forward method with patched version
        cls.forward = types.MethodType(
            compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
        )
        cls.act_offloading_ctx_manager = OffloadActivations()

    @classmethod
    def enable_liger_fce(cls, enable_act_offloading=True):
        from liger_kernel.transformers.model.llama import (
            lce_forward as llama_lce_forward,
        )

        if enable_act_offloading:
            lce_source = inspect.getsource(llama_lce_forward)
            lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
            # replace forward method with patched version
            cls.forward = types.MethodType(
                compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
                cls,
            )
        else:
            cls.forward = types.MethodType(llama_lce_forward, cls)

    @classmethod
    def patch_hf_ga(cls):
        # bugfix patch for gradient accumulation
        forward_source = inspect.getsource(cls.forward)
        forward_source = forward_source.replace(
            HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
        )
        forward_source = forward_source.replace(
            HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
        )
        forward_source, _ = detab_code(forward_source)
        # replace forward method with patched version
        cls.forward = types.MethodType(
            compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
        )


def replace_auto_model():
    from transformers import LlamaConfig
    from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING

    MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
    AxolotlLlamaForCausalLM.set_forward()

    return AxolotlLlamaForCausalLM
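The class above patches by fetching a method's source with inspect, string-replacing a block, and recompiling; detab_code plays the dedent role. A self-contained toy version of the same idiom (the names here are ours, not axolotl's):

import inspect
import textwrap

class Greeter:
    def greet(self):
        return "hello"

src = textwrap.dedent(inspect.getsource(Greeter.greet))
src = src.replace('"hello"', '"hi"')
namespace = {}
exec(src, namespace)  # nosec - recompile the edited source into a fresh function
Greeter.greet = namespace["greet"]
print(Greeter().greet())  # prints: hi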
80 src/axolotl/monkeypatch/trainer_fsdp_optim.py Normal file
@@ -0,0 +1,80 @@
"""
|
||||||
|
fix for FSDP optimizer save in trainer w 4.47.0
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
||||||
|
|
||||||
|
ORIGINAL_TRAINER_CODE = """
|
||||||
|
|
||||||
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_TRAINER_CODE = """
|
||||||
|
|
||||||
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_loop_code() -> str:
|
||||||
|
training_loop = inspect.getsource(
|
||||||
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_loop_is_patchable() -> bool:
|
||||||
|
training_loop = get_training_loop_code()
|
||||||
|
training_loop, _ = detab_code(training_loop)
|
||||||
|
return ORIGINAL_TRAINER_CODE in training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_loop_for_fsdp():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for fsdp with optimizer save
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_loop = get_training_loop_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
training_loop
|
||||||
|
)
|
||||||
|
training_loop, _ = detab_code(training_loop)
|
||||||
|
if ORIGINAL_TRAINER_CODE not in training_loop:
|
||||||
|
return
|
||||||
|
|
||||||
|
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
|
||||||
|
training_loop = training_loop.replace(
|
||||||
|
"def _inner_training_loop(",
|
||||||
|
"def _fixed_inner_training_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in training_loop:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||||
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
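A usage sketch, under the assumption that the patch should be applied at startup, before any Trainer is constructed (the module path and function name are the ones added in this diff):

from axolotl.monkeypatch.trainer_fsdp_optim import patch_training_loop_for_fsdp

# Silently no-ops when the upstream _inner_training_loop source has drifted
# and ORIGINAL_TRAINER_CODE no longer matches.
patch_training_loop_for_fsdp()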
290 src/axolotl/monkeypatch/trainer_grad_accum.py Normal file
@@ -0,0 +1,290 @@
"""
|
||||||
|
fix for FSDP gradient accumulation
|
||||||
|
see https://github.com/huggingface/transformers/pull/35128
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import LlamaForCausalLM, Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||||
|
|
||||||
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CONTEXT_CODE = """
|
||||||
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_LLAMA_FCLM_CODE = """
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
|
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_step_code() -> str:
|
||||||
|
training_step = inspect.getsource(
|
||||||
|
Trainer.training_step # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_step
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_step_is_patchable() -> bool:
|
||||||
|
training_step = get_training_step_code()
|
||||||
|
training_step, _ = detab_code(training_step)
|
||||||
|
return ORIGINAL_CONTEXT_CODE in training_step
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_step_for_ga():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for gradient accumulation
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_step = get_training_step_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_training_step = training_step # pylint: disable=protected-access
|
||||||
|
training_step, _ = detab_code(training_step)
|
||||||
|
if ORIGINAL_CONTEXT_CODE not in training_step:
|
||||||
|
return
|
||||||
|
# assert (
|
||||||
|
# ORIGINAL_CONTEXT_CODE in training_step
|
||||||
|
# ), "Original training_step code not found"
|
||||||
|
|
||||||
|
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||||
|
training_step = training_step.replace(
|
||||||
|
"def training_step(",
|
||||||
|
"def _fixed_training_step(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in training_step:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching training_step")
|
||||||
|
Trainer.training_step = ( # pylint: disable=protected-access
|
||||||
|
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_forward_code() -> str:
|
||||||
|
forward = inspect.getsource(
|
||||||
|
LlamaForCausalLM.forward # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_is_patchable() -> bool:
|
||||||
|
forward = get_model_forward_code()
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
return ORIGINAL_LLAMA_FCLM_CODE in forward
|
||||||
|
|
||||||
|
|
||||||
|
def patch_forward_for_ga():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for gradient accumulation
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
forward = get_model_forward_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
|
||||||
|
return
|
||||||
|
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
|
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
|
||||||
|
forward = forward.replace(
|
||||||
|
"def forward(",
|
||||||
|
"def _fixed_forward(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
|
if item in forward:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching forward")
|
||||||
|
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
|
||||||
|
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ORIGINAL_TRAINER_CODE = """
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i != len(batch_samples) - 1
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
with context():
|
||||||
|
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_TRAINER_CODE = """
|
||||||
|
disable_deepspeed_no_sync = (
|
||||||
|
self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||||
|
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
|
||||||
|
)
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
with context():
|
||||||
|
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_loop_code() -> str:
|
||||||
|
training_loop = inspect.getsource(
|
||||||
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_loop_is_patchable() -> bool:
|
||||||
|
training_loop = get_training_loop_code()
|
||||||
|
training_loop, _ = detab_code(training_loop)
|
||||||
|
return ORIGINAL_TRAINER_CODE in training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_loop_for_deepspeed_0_16_x():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for deepspeed GA
|
||||||
|
|
||||||
|
see https://github.com/huggingface/transformers/pull/35157
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_loop = get_training_loop_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
training_loop
|
||||||
|
)
|
||||||
|
training_loop, _ = detab_code(training_loop)
|
||||||
|
if ORIGINAL_TRAINER_CODE not in training_loop:
|
||||||
|
return
|
||||||
|
|
||||||
|
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
|
||||||
|
training_loop = training_loop.replace(
|
||||||
|
"def _inner_training_loop(",
|
||||||
|
"def _fixed_inner_training_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in training_loop:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||||
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
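Aside: the reason training_step and LlamaForCausalLM.forward are patched to thread num_items_in_batch through to the loss is that averaging each micro-batch separately weights tokens unevenly when micro-batches hold different numbers of trainable tokens. A toy calculation (illustrative numbers only, not taken from the diff):

per_token_losses = [[2.0, 2.0, 2.0], [4.0]]  # two accumulated micro-batches

# Naive GA: mean per micro-batch, then mean of the means.
naive = sum(sum(b) / len(b) for b in per_token_losses) / len(per_token_losses)

# Fixed GA: one division by the total token count across all micro-batches.
num_items_in_batch = sum(len(b) for b in per_token_losses)
fixed = sum(sum(b) for b in per_token_losses) / num_items_in_batch

print(naive, fixed)  # 3.0 vs 2.5; they agree only when micro-batches are equal-sized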
@@ -9,10 +9,7 @@ import torch
 from accelerate.logging import get_logger
 from peft import PeftModelForCausalLM
 from torch import nn
-from transformers.models.llama.modeling_llama import (
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-)
+from transformers.models.llama.modeling_llama import LlamaFlashAttention2

 LOG = get_logger("axolotl.monkeypatch.unsloth")

@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
     return attn_output


-def get_forward_code() -> str:
-    forward = inspect.getsource(LlamaForCausalLM.forward)
-    return forward
-
-
 def get_self_attn_code() -> str:
     forward = inspect.getsource(LlamaFlashAttention2.forward)
     return forward

@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:


 def detab_code(code: str) -> Tuple[str, str]:
-    spaces = re.match(r"([\s\t]{1,})", code).group(0)
-    code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
+    try:
+        spaces = re.match(r"([\s\t]{1,})", code).group(0)
+        code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
+    except AttributeError:
+        return code, ""
     return code, spaces


+self_attn_lora_patched = False  # pylint: disable=invalid-name
+
+
 def patch_self_attn_lora():
+    global self_attn_lora_patched  # pylint: disable=global-statement
+    if self_attn_lora_patched:
+        # prevent patching multiple times
+        return
     self_attn_forward = get_self_attn_code()
     LlamaFlashAttention2._original_forward = (  # pylint: disable=protected-access
         self_attn_forward

@@ -139,6 +141,7 @@ def patch_self_attn_lora():
         globals(),
     )
     exec(self_attn_forward, globals())  # pylint: disable=exec-used  # nosec B102
+    self_attn_lora_patched = True
     LOG.info("patching unsloth attn lora", main_process_only=True)
     LlamaFlashAttention2.forward = (
         unsloth_attn_forward  # pylint: disable=undefined-variable  # noqa: F821
@@ -28,6 +28,8 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
         :return:
         """
+
+        max_length = self.prompter.max_length

         self.messages = "chosen_messages"
         # pylint: disable=duplicate-code
         prompt[self.messages] = []

@@ -39,6 +41,16 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
         prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
         chosen_tokenized = super().tokenize_prompt(prompt)

+        if len(chosen_tokenized["input_ids"]) > max_length:
+            LOG.warning(
+                f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
+            )
+
+        chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
+        chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
+            :max_length
+        ]
+
         self.messages = "rejected_messages"
         # pylint: disable=duplicate-code
         prompt[self.messages] = []

@@ -52,6 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
         )
         rejected_tokenized = super().tokenize_prompt(prompt)

+        if len(rejected_tokenized["input_ids"]) > max_length:
+            LOG.warning(
+                f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
+            )
+
+        rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
+            :max_length
+        ]
+        rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
+            :max_length
+        ]
+
         return {
             "input_ids_chosen": chosen_tokenized["input_ids"],
             "attention_mask_chosen": chosen_tokenized["attention_mask"],

@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         "roles": ds_cfg.get("roles"),
         "drop_system_message": ds_cfg.get("drop_system_message", False),
         # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
-        "max_length": cfg.sequence_len + 1
-        if not cfg.reward_model
-        else cfg.sequence_len,
+        "max_length": (
+            cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
+        ),
     }

     strategy_params = {
@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
             "gpt": "assistant",
             "system": "system",
         }
+
         self.message_field_role = message_field_role
         self.message_field_content = message_field_content
         self.message_field_training = message_field_training

@@ -53,21 +54,9 @@ class ChatTemplatePrompter(Prompter):
         self.drop_system_message = drop_system_message

     def build_prompt(self, conversation, add_generation_prompt=False, images=None):
-        turns = [
-            {
-                "role": self.roles[t[self.message_field_role]],
-                "content": t[self.message_field_content],
-                "training": t.get(self.message_field_training, None),
-            }
-            for t in conversation
-        ]
-
-        if self.drop_system_message and turns[0]["role"] == "system":
-            turns = turns[1:]
-
         if self.processor:
             text = self.processor.apply_chat_template(
-                turns,
+                conversation,
                 chat_template=self.chat_template,
                 tokenize=False,
                 add_generation_prompt=add_generation_prompt,

@@ -76,8 +65,6 @@ class ChatTemplatePrompter(Prompter):
                 text=text,
                 images=images,
                 return_tensors="pt",
-                truncation=True,
-                max_length=self.max_length,
             )
             # workaround since processor works in batches instead of single examples
             for k, val in batch.items():

@@ -88,9 +75,7 @@ class ChatTemplatePrompter(Prompter):
                 return batch

         return self.tokenizer.apply_chat_template(
-            turns,
-            truncation=True,
-            max_length=self.max_length,
+            conversation,
             add_generation_prompt=add_generation_prompt,
             chat_template=self.chat_template,
         )

@@ -215,7 +200,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         train_on_eos=None,
     ):
         super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
-        self.roles_to_train = roles_to_train if roles_to_train is not None else []
+
+        self.roles_to_train = []
+        if roles_to_train:
+            # map roles if exist in prompter.roles else use the role as is
+            self.roles_to_train = [
+                prompter.roles.get(role, role) for role in roles_to_train
+            ]
+
         self.train_on_eos = train_on_eos
         self.images = "images"

@@ -262,30 +254,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

             return tokenized_prompt

-        turns = prompt[self.messages]
+        turns = self.get_conversation_thread(prompt)
         input_ids = self.prompter.build_prompt(turns)
         labels = [IGNORE_TOKEN_ID] * len(input_ids)

         last_eos_idx = -1
         for index, turn in enumerate(turns):
-            role = turn.get(self.prompter.message_field_role)
-            content = turn.get(self.prompter.message_field_content)
-            train_turn = turn.get(self.prompter.message_field_training)
-            train_detail = turn.get(self.prompter.message_field_training_detail)
+            role = turn.get("role")
+            content = turn.get("content")
+            train_turn = turn.get("training")
+            train_detail = turn.get("training_detail")

             LOG.debug(
                 f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
             )

-            should_train = (
-                train_turn
-                if train_turn is not None
-                else (
-                    bool(train_detail is not None)
-                    if train_detail is not None
-                    else self.train_on_inputs or role in self.roles_to_train
-                )
-            )
+            should_train = None
+            if train_turn is not None:
+                should_train = train_turn
+            elif train_detail is not None:
+                should_train = bool(train_detail)
+            else:
+                should_train = self.train_on_inputs or role in self.roles_to_train

             LOG.debug(f"Should train: {should_train}")

@@ -293,6 +283,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                 conversation_ids=input_ids, turn=index, turn_content=turn
             )

+            if turn_start_idx == -1 or turn_end_idx == -1:
+                LOG.warning(f"Failed to find boundaries for turn {index}")
+
             LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

             if should_train and turn_start_idx != -1 and turn_end_idx != -1:

@@ -313,7 +306,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                 labels[turn_start_idx:turn_end_idx] = input_ids[
                     turn_start_idx:turn_end_idx
                 ]
-                LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
+                LOG.debug(
+                    f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
+                )

             LOG.debug(f"Labels after processing turn {index}: {labels}")

@@ -351,52 +346,73 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                 return i
         return -1

-    def find_turn(self, conversation_ids, turn, turn_content):
+    def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
         """
         Locate the starting and ending indices of the specified turn in a conversation.
-
-        Args:
-            conversation_ids (list[int]): Token IDs representing the conversation.
-            turn (int): The turn number to locate (based on EOS tokens).
-            turn_content (str): String containing the content of the turn.
-
-        Returns:
-            tuple: (start_idx, end_idx) indices of the start and end of the turn content.
-            Returns (-1, -1) if the turn content is not found.
         """
-        content = turn_content.get(self.prompter.message_field_content, "")
+        content = turn_content.get("content")
         content_ids = self.tokenizer.encode(content, add_special_tokens=False)

-        eos_token_id = self.tokenizer.eos_token_id
-        eos_count = 0
-        start_search_idx = 0
-
-        # Locate the starting index after the specified number of EOS tokens
-        for i, token_id in enumerate(conversation_ids):
-            if token_id == eos_token_id:
-                eos_count += 1
-                if eos_count == turn:
-                    start_search_idx = (
-                        i + 1
-                    )  # Start searching after the specified turn's EOS token
-                    break
-
-        # Find the start index of the content within the conversation
-        start_idx = -1
-        for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
-            if conversation_ids[i : i + len(content_ids)] == content_ids:
-                start_idx = i
-                break
-
-        if start_idx != -1:
-            end_idx = start_idx + len(content_ids)
-        else:
-            end_idx = -1
-
-        return start_idx, end_idx
+        LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
+
+        if not content_ids:
+            LOG.warning(f"Empty content for turn {turn}")
+            return -1, -1
+
+        # For first turn, start from beginning
+        if turn == 0:
+            start_search_idx = 0
+        else:
+            # For subsequent turns, find the previous EOS token
+            eos_token_id = self.tokenizer.eos_token_id
+            eos_count = 0
+            start_search_idx = 0
+
+            for i, token_id in enumerate(conversation_ids):
+                if token_id == eos_token_id:
+                    eos_count += 1
+                    if eos_count == turn:  # Find the nth EOS token where n = turn
+                        start_search_idx = i + 1
+                        break
+
+        # we can optimize this to only search for a few tokens from start_search_idx
+        # but it would risk missing the content if it's not found within the first few tokens or
+        # if start_search_idx cannot be found above.
+        last_index = len(conversation_ids) - len(content_ids) + 1
+
+        if last_index < start_search_idx:
+            LOG.warning(
+                f"last_index to search is less than start_search_idx for turn {turn}"
+            )
+            return -1, -1
+
+        # Search for content starting from start_search_idx
+        first_elem = content_ids[0]
+        for i in range(start_search_idx, last_index):
+            # Quick check of first element before doing full comparison
+            if conversation_ids[i] == first_elem:
+                # Check if the rest of the content matches
+                if conversation_ids[i : i + len(content_ids)] == content_ids:
+                    LOG.debug(f"Found turn {turn} content at position {i}")
+                    return i, i + len(content_ids)
+
+        return -1, -1

     def get_conversation_thread(self, prompt):
-        return prompt[self.messages]
+        turns = [
+            {
+                "role": self.prompter.roles[t[self.prompter.message_field_role]],
+                "content": t[self.prompter.message_field_content],
+                "training": t.get(self.prompter.message_field_training),
+                "training_detail": t.get(self.prompter.message_field_training_detail),
+            }
+            for t in prompt[self.messages]
+        ]
+
+        if self.prompter.drop_system_message and turns[0]["role"] == "system":
+            turns = turns[1:]
+
+        return turns

     def get_images(self, prompt):
         return prompt.get(self.images, None)
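Aside: stripped of logging, the rewritten find_turn() is a plain subsequence search seeded past the turn-th EOS token, with a cheap first-token comparison before each full slice compare. A standalone sketch of that core (toy token IDs, hypothetical helper name):

def find_subsequence(haystack: list[int], needle: list[int], start: int = 0) -> int:
    """Return the index of needle in haystack at or after start, else -1."""
    if not needle:
        return -1
    last = len(haystack) - len(needle) + 1
    first = needle[0]
    for i in range(start, last):
        # quick first-element check before the full slice comparison
        if haystack[i] == first and haystack[i : i + len(needle)] == needle:
            return i
    return -1


conversation = [5, 9, 2, 7, 8, 2, 7, 8, 4]  # pretend token id 2 is EOS
content = [7, 8, 4]
print(find_subsequence(conversation, content, start=3))  # -> 6 (search begins past the first EOS)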
@@ -259,14 +259,7 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

     if not cfg.hub_model_id:
-        from huggingface_hub import HfApi
-        from huggingface_hub.utils import RepositoryNotFoundError
-
         try:
-            # Check to make sure the base model is from HuggingFace not a local directory
-            hf_api = HfApi()
-            hf_api.model_info(cfg.base_model)
-
             model_card_kwarg = {
                 "model_name": cfg.output_dir.lstrip("./")
                 .encode("utf-8")

@@ -274,16 +267,22 @@ def train(
             }
             if cfg.datasets is not None:
                 if cfg.rl is not None or cfg.reward_model:
-                    model_card_kwarg["dataset_name"] = [
+                    dataset_tags = [
                         d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                     ]
+                    if dataset_tags:
+                        # guard as create_model_card may fail if dataset_tags is empty list
+                        model_card_kwarg["dataset_name"] = dataset_tags
                 else:
-                    model_card_kwarg["dataset_tags"] = [
+                    dataset_tags = [
                         d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
                     ]
+                    if dataset_tags:
+                        # guard as create_model_card may fail if dataset_tags is empty list
+                        model_card_kwarg["dataset_tags"] = dataset_tags

             trainer.create_model_card(**model_card_kwarg)
-        except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
+        except (AttributeError, UnicodeDecodeError):
             pass
     elif cfg.hub_model_id:
         # defensively push to the hub to ensure the model card is updated
@@ -66,10 +66,7 @@ class EvalFirstStepCallback(
         control: TrainerControl,
         **kwargs,
     ):
-        if (
-            args.evaluation_strategy == IntervalStrategy.STEPS
-            and state.global_step == 1
-        ):
+        if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
             control.should_evaluate = True
         return control

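Note: the condition now reads args.eval_strategy, tracking the rename of evaluation_strategy to eval_strategy in newer transformers releases. Where a codebase has to straddle versions that expose either name, a tolerant lookup is one option (illustrative sketch, not part of this diff):

strategy = getattr(args, "eval_strategy", None)
if strategy is None:
    strategy = getattr(args, "evaluation_strategy", None)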
File diff suppressed because one or more lines are too long
@@ -22,7 +22,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
     processor: ProcessorMixin
     return_tensors: str = "pt"
     chat_template: Optional[str] = None
-    chat_template_type: Optional[str] = None
     packing: bool = False
     max_images: int = -1
     padding: Union[bool, str, PaddingStrategy] = True

@@ -36,187 +35,142 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         self, examples: list[Union[list[int], Any, dict[str, Any]]]
     ) -> dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.

         return self.__class__.process_rows(
-            examples,
-            self.processor,
-            self.chat_template,
-            self.max_images,
-            chat_template_type=self.chat_template_type,
+            examples, self.processor, self.chat_template, self.max_images
         )

     @staticmethod
-    def preprocess(examples: list[dict]) -> list[dict]:
-        """
-        Preprocess conversation examples to ensure consistent format.
-        Converts different conversation formats to OpenAI format with 'messages'.
-        Supports two formats:
-        1. OpenAI format with 'messages'
-        2. Legacy format with 'conversations'
-
-        Args:
-            examples: list of conversation dictionaries
-        Returns:
-            dict in OpenAI format with 'messages' key
-
-        Raises:
-            ValueError: If the conversation format is not supported
-        """
-        role_mapping = {
-            "human": "user",
-            "gpt": "assistant",
-        }
-
-        def normalize_role(role: str) -> str:
-            """Normalize role names to OpenAI format. Default to original role if not found."""
-            return role_mapping.get(role, role)
-
-        def convert_legacy_format(example: dict) -> dict:
-            """Convert legacy 'conversations' format to OpenAI 'messages' format."""
-            messages = [
-                {
-                    "role": normalize_role(convo["from"]),
-                    "content": convo["value"],
-                }
-                for convo in example["conversations"]
-            ]
-
-            # Create new dict without 'conversations' key
-            result = deepcopy(example)
-            result.pop("conversations")
-            return {"messages": messages, **result}
-
-        processed_examples = []
-        for example in examples:
-            # OpenAI format
-            if "messages" in example:
-                processed_examples.append(example)
-
-            # Legacy format
-            elif "conversations" in example:
-                processed_examples.append(convert_legacy_format(example))
-
-            else:
-                raise ValueError(
-                    "Only `messages` and `conversations` message keys are currently supported."
-                )
-
-        return processed_examples
-
-    @staticmethod
-    def process_images(examples, max_images):
-        """
-        Process images from examples, ensuring consistency in image presence and applying max_images limit.
-
-        Args:
-            examples: List of dictionaries that may contain 'images' key
-            max_images: Maximum number of images to keep per example (0 means no limit)
-
-        Returns:
-            Either None (if no images) or List[Image objects] (if all examples have images)
-
-        Raises:
-            ValueError: If there's a mix of None and non-None images
-        """
-
-        def get_image(example):
-            if "images" not in example:
-                return None
-            images = example["images"]
-            if isinstance(images, str):
-                return Image.open(images)
-            return images
-
-        images = [get_image(example) for example in examples]
-
-        # Count None and non-None images
-        none_count = sum(1 for img in images if img is None)
-
-        # All images are None
-        if none_count == len(images):
-            return None
-
-        # Mix of None and non-None images
-        if none_count > 0:
-            raise ValueError(
-                "All images should be either None or not None. "
-                "Please provide images for all examples or None."
-            )
-
-        # Apply max_images limit if specified
-        if max_images > 0:
-            images = [
-                (
-                    img_batch[:max_images]
-                    if isinstance(img_batch, (list, tuple))
-                    else img_batch
-                )
-                for img_batch in images
-            ]
-
-        return images
-
-    @staticmethod
-    def pixtral_chat_conversion(messages):
-        is_single_message = not isinstance(messages, list)
-        if is_single_message:
-            messages = [messages]
-
-        for i, message in enumerate(messages):
-            if message["role"] == "user":
-                for j, content in enumerate(message["content"]):
-                    if "type" in content and content["type"] == "text":
-                        messages[i]["content"][j] = {
-                            "type": "text",
-                            "content": content["text"],
-                        }
-
-            if message["role"] == "assistant":
-                messages[i]["content"] = message["content"][0]["text"]
-
-        if is_single_message:
-            return messages[0]
-        return messages
-
-    @staticmethod
-    def process_rows(
-        examples,
-        processor,
-        chat_template,
-        max_images,
-        length_only=False,
-        chat_template_type=None,
-    ):
+    def process_rows(examples, processor, chat_template, max_images, length_only=False):
         # HINT: use `_torch_collate_batch` to stack and pad tensors
         # see also DataCollatorWithFlattening and DefaultDataCollator

         # *** This is COPIED from the trl example sft_vlm.py code ***
         # use this as a starting point

+        def _preprocess(examples: list[dict]) -> list[dict]:
+            """
+            Preprocess conversation examples to ensure consistent format.
+
+            Converts different conversation formats to OpenAI format with 'messages'.
+            Supports two formats:
+            1. OpenAI format with 'messages'
+            2. Legacy format with 'conversations'
+
+            Args:
+                examples: list of conversation dictionaries
+
+            Returns:
+                dict in OpenAI format with 'messages' key
+
+            Raises:
+                ValueError: If the conversation format is not supported
+            """
+            role_mapping = {
+                "human": "user",
+                "gpt": "assistant",
+            }
+
+            def normalize_role(role: str) -> str:
+                """Normalize role names to OpenAI format. Default to original role if not found."""
+                return role_mapping.get(role, role)
+
+            def convert_legacy_format(example: dict) -> dict:
+                """Convert legacy 'conversations' format to OpenAI 'messages' format."""
+                messages = [
+                    {
+                        "role": normalize_role(convo["from"]),
+                        "content": convo["value"],
+                    }
+                    for convo in example["conversations"]
+                ]
+
+                # Create new dict without 'conversations' key
+                result = deepcopy(example)
+                result.pop("conversations")
+                return {"messages": messages, **result}
+
+            processed_examples = []
+            for example in examples:
+                # OpenAI format
+                if "messages" in example:
+                    processed_examples.append(example)
+
+                # Legacy format
+                elif "conversations" in example:
+                    processed_examples.append(convert_legacy_format(example))
+
+                else:
+                    raise ValueError(
+                        "Only `messages` and `conversations` message keys are currently supported."
+                    )
+
+            return processed_examples
+
+        def _process_images(examples, max_images):
+            """
+            Process images from examples, ensuring consistency in image presence and applying max_images limit.
+
+            Args:
+                examples: List of dictionaries that may contain 'images' key
+                max_images: Maximum number of images to keep per example (0 means no limit)
+
+            Returns:
+                Either None (if no images) or List[Image objects] (if all examples have images)
+
+            Raises:
+                ValueError: If there's a mix of None and non-None images
+            """
+
+            def get_image(example):
+                if "images" not in example:
+                    return None
+                images = example["images"]
+                if isinstance(images, str):
+                    return Image.open(images)
+                return images
+
+            images = [get_image(example) for example in examples]
+
+            # Count None and non-None images
+            none_count = sum(1 for img in images if img is None)
+
+            # All images are None
+            if none_count == len(images):
+                return None
+
+            # Mix of None and non-None images
+            if none_count > 0:
+                raise ValueError(
+                    "All images should be either None or not None. "
+                    "Please provide images for all examples or None."
+                )
+
+            # Apply max_images limit if specified
+            if max_images > 0:
+                images = [
+                    (
+                        img_batch[:max_images]
+                        if isinstance(img_batch, (list, tuple))
+                        else img_batch
+                    )
+                    for img_batch in images
+                ]
+
+            return images
+
         # Preprocess the examples
-        examples = __class__.preprocess(examples)
+        examples = _preprocess(examples)

         # Get the texts and images, and apply the chat template
-        if chat_template_type == "pixtral":
-            texts = [
-                processor.apply_chat_template(
-                    __class__.pixtral_chat_conversion(example["messages"]),
-                    chat_template=chat_template,
-                    tokenize=False,
-                )
-                for example in examples
-            ]
-        else:
-            texts = [
-                processor.apply_chat_template(
-                    example["messages"], chat_template=chat_template, tokenize=False
-                )
-                for example in examples
-            ]
+        texts = [
+            processor.apply_chat_template(
+                example["messages"], chat_template=chat_template, tokenize=False
+            )
+            for example in examples
+        ]

-        images = __class__.process_images(examples, max_images=max_images)
-        if chat_template_type == "llava":
-            # LLava1.5 does not support multiple images
-            images = [image[0] for image in images]
+        images = _process_images(examples, max_images=max_images)

         # Tokenize the texts and process the images
         batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

@@ -225,12 +179,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         labels = batch["input_ids"].clone()
         labels[labels == processor.tokenizer.pad_token_id] = -100  #
         # Ignore the image token index in the loss computation (model specific)
-        if chat_template_type == "qwen2_vl":
-            image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
-        else:
-            image_token_id = processor.tokenizer.convert_tokens_to_ids(
-                processor.image_token
-            )
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            processor.image_token
+        )
         labels[labels == image_token_id] = -100
         batch["labels"] = labels
@@ -132,7 +132,7 @@ def normalize_config(cfg):

     cfg.is_multimodal = (
         hasattr(model_config, "model_type")
-        and model_config.model_type in ["llava", "mllama", "qwen2_vl", "qwen2_5_vl"]
+        and model_config.model_type in ["llava", "mllama"]
         or any(
             multimodal_name in cfg.base_model.lower()
             for multimodal_name in [

@@ -145,12 +145,7 @@ def normalize_config(cfg):
         cfg.processor_config = (
             cfg.processor_config or cfg.base_model_config or cfg.base_model
         )
-        try:
-            model_config = model_config.text_config
-        except AttributeError:
-            # for qwen2_vl
-            model_config = model_config.get_text_config()
+        model_config = model_config.text_config

     cfg.model_config_type = model_config.model_type

@@ -158,7 +153,7 @@ def normalize_config(cfg):
     cfg.is_llama_derived_model = (
         (
             hasattr(model_config, "model_type")
-            and model_config.model_type == ["llama", "mllama_text_model"]
+            and model_config.model_type in ["llama", "mllama_text_model"]
         )
         or cfg.is_llama_derived_model
         or "llama" in cfg.base_model.lower()
@@ -51,7 +51,6 @@ class ChatTemplate(str, Enum):
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
     llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
-    llava = "llava"  # pylint: disable=invalid-name
     phi_3 = "phi_3"  # pylint: disable=invalid-name
     phi_35 = "phi_35"  # pylint: disable=invalid-name
     deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name

@@ -61,8 +60,6 @@ class ChatTemplate(str, Enum):
     tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
     exaone = "exaone"  # pylint: disable=invalid-name
     metharme = "metharme"  # pylint: disable=invalid-name
-    pixtral = "pixtral"  # pylint: disable=invalid-name
-    qwen2_vl = "qwen2_vl"  # pylint: disable=invalid-name


 class DeprecatedParameters(BaseModel):

@@ -682,6 +679,7 @@ class AxolotlInputConfig(
         default=False
     )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+    activation_offloading: Optional[bool] = None

     unfrozen_parameters: Optional[List[str]] = None

@@ -1478,6 +1476,27 @@ class AxolotlInputConfig(

         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_kto_config(cls, data):
+        if data.get("rl") == "kto":
+            if data.get("sample_packing") or data.get("eval_sample_packing"):
+                raise ValueError("sample_packing is not supported with kto")
+
+            if data.get("remove_unused_columns") is not False:
+                raise ValueError("Set `remove_unused_columns: False` when using kto")
+
+            if data.get("gradient_checkpointing") and not (
+                data.get("gradient_checkpointing_kwargs")
+                and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
+                and data["gradient_checkpointing_kwargs"].get("use_reentrant")
+            ):
+                raise ValueError(
+                    "Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
+                )
+
+        return data
+

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""

@@ -1524,19 +1543,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_hopper_8bit_lora(cls, data):
-        is_sm_90: bool = (
-            data["capabilities"]
-            and data["capabilities"].get("compute_capability") == "sm_90"
-        )
-        if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
-            # see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
-            raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
-
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_fsdp_deepspeed(cls, data):
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import requests
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
@@ -44,7 +42,11 @@ from axolotl.prompters import (
|
|||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
from axolotl.utils.data.utils import (
|
||||||
|
deduplicate_and_log_datasets,
|
||||||
|
md5,
|
||||||
|
retry_on_request_exceptions,
|
||||||
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
@@ -55,27 +57,6 @@ from axolotl.utils.trainer import (
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
|
||||||
def decorator(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except (
|
|
||||||
requests.exceptions.ReadTimeout,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
) as exc:
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
time.sleep(delay)
|
|
||||||
else:
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||||
prompters = []
|
prompters = []
|
||||||
|
|||||||
@@ -1,13 +1,57 @@
 """data handling helpers"""
+import functools
 import hashlib
 import logging
+import time
+from enum import Enum

+import huggingface_hub
+import requests
 from datasets import Dataset

 LOG = logging.getLogger("axolotl")


+class RetryStrategy(Enum):
+    """
+    Enum for retry strategies.
+    """
+
+    CONSTANT = 1
+    LINEAR = 2
+    EXPONENTIAL = 3
+
+
+def retry_on_request_exceptions(
+    max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
+):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except (
+                    requests.exceptions.ReadTimeout,
+                    requests.exceptions.ConnectionError,
+                    huggingface_hub.errors.HfHubHTTPError,
+                ) as exc:
+                    if attempt < max_retries - 1:
+                        if retry_strategy == RetryStrategy.EXPONENTIAL:
+                            step_delay = delay * 2**attempt
+                        elif retry_strategy == RetryStrategy.LINEAR:
+                            step_delay = delay * (attempt + 1)
+                        else:
+                            step_delay = delay  # Use constant delay.
+                        time.sleep(step_delay)
+                    else:
+                        raise exc
+
+        return wrapper
+
+    return decorator
+
+
 def md5(to_hash: str, encoding: str = "utf-8") -> str:
     try:
         return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
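Aside: with the reworked helper, LINEAR (the default) sleeps delay, 2*delay, 3*delay between attempts, while EXPONENTIAL doubles each time. A hypothetical caller (the decorated function is made up for illustration; the import path matches the diff above):

from axolotl.utils.data.utils import RetryStrategy, retry_on_request_exceptions


@retry_on_request_exceptions(max_retries=4, delay=2, retry_strategy=RetryStrategy.EXPONENTIAL)
def fetch_remote_dataset(repo_id: str):
    # A transient ReadTimeout, ConnectionError, or HfHubHTTPError raised in here
    # is retried after sleeping 2s, 4s, then 8s; the fourth failure re-raises.
    ...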
@@ -30,7 +30,6 @@ from transformers import ( # noqa: F401
|
|||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForImageTextToText,
|
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -92,11 +91,7 @@ def get_module_class_from_name(module, name):
|
|||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
try:
|
model_config = model_config.text_config
|
||||||
model_config = model_config.text_config
|
|
||||||
except AttributeError:
|
|
||||||
# for qwen2_vl
|
|
||||||
model_config = model_config.get_text_config()
|
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
@@ -372,11 +367,7 @@ class ModelLoader:
         # init model config
         self.model_config = load_model_config(cfg)
         if cfg.is_multimodal:
-            try:
-                self.text_model_config = self.model_config.text_config
-            except AttributeError:
-                # for qwen2_vl
-                self.text_model_config = self.model_config.get_text_config()
+            self.text_model_config = self.model_config.text_config
         else:
             self.text_model_config = self.model_config
 
@@ -389,12 +380,43 @@ class ModelLoader:
         plugin_manager = PluginManager.get_instance()
         plugin_manager.pre_model_load(self.cfg)
 
+        if self.cfg.model_config_type == "llama":
+            from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
+
+            AxolotlLlamaForCausalLM = replace_auto_model()
+
+            AxolotlLlamaForCausalLM.patch_hf_ga()
+            if self.cfg.activation_offloading:
+                AxolotlLlamaForCausalLM.enable_act_offloading()
+
+        if self.cfg.fsdp:
+            from axolotl.monkeypatch.trainer_fsdp_optim import (
+                patch_training_loop_for_fsdp,
+            )
+
+            patch_training_loop_for_fsdp()
+        elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1:
+            from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_training_loop_for_deepspeed_0_16_x,
+            )
+
+            patch_training_loop_for_deepspeed_0_16_x()
+
         if self.cfg.gradient_checkpointing == "unsloth":
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
         if self.cfg.flash_attention:
             self.patch_attention()
 
+        if self.cfg.model_config_type == "llama":
+            from axolotl.monkeypatch.trainer_grad_accum import (
+                patch_forward_for_ga,
+                patch_training_step_for_ga,
+            )
+
+            patch_forward_for_ga()
+            patch_training_step_for_ga()
+
         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(
                 "Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -406,10 +428,14 @@ class ModelLoader:
             and self.cfg.flash_attention
             and self.cfg.sample_packing
         ):
-            has_remote_code = (
-                "auto_map" in self.model_config
-                and "AutoModelForCausalLM" in self.model_config["auto_map"]
-            )
+            if "auto_map" in self.model_config:
+                try:
+                    auto_map_config = self.model_config["auto_map"]
+                except TypeError:
+                    auto_map_config = self.model_config.auto_map
+                has_remote_code = "AutoModelForCausalLM" in auto_map_config
+            else:
+                has_remote_code = False
             if has_remote_code and self.cfg.trust_remote_code is False:
                 # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
                 has_remote_code = self.cfg.trust_remote_code
@@ -562,10 +588,6 @@ class ModelLoader:
             self.AutoModelLoader = (  # pylint: disable=invalid-name
                 MllamaForConditionalGeneration
             )
-        elif self.model_config.model_type == "qwen2_vl":
-            self.AutoModelLoader = (  # pylint: disable=invalid-name
-                AutoModelForImageTextToText
-            )
         else:
             self.AutoModelLoader = (
                 AutoModelForVision2Seq  # pylint: disable=invalid-name
@@ -1058,9 +1080,7 @@ class ModelLoader:
             and self.model.get_input_embeddings().num_embeddings < embeddings_len
         ):
             resize_kwargs = {}
-            if self.cfg.mean_resizing_embeddings is not None and not (
-                self.model_config.model_type == "llava"
-            ):
+            if self.cfg.mean_resizing_embeddings is not None:
                 resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
             self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
         else:
@@ -1172,6 +1192,8 @@ class ModelLoader:
 
         self.apply_lora_patch()
 
+        # self.apply_patches_to_model()
+
         for _ in range(3):
             gc.collect()
             torch.cuda.empty_cache()
0 src/axolotl/utils/optimizers/__init__.py Normal file
104 src/setuptools_axolotl_dynamic_dependencies.py Normal file
@@ -0,0 +1,104 @@
+"""
+dynamic requirements for axolotl
+"""
+import platform
+import re
+from importlib.metadata import PackageNotFoundError, version
+
+from setuptools.command.build_py import build_py as _build_py
+
+
+# pylint: disable=duplicate-code
+def parse_requirements():
+    _install_requires = []
+    _dependency_links = []
+    with open("./requirements.txt", encoding="utf-8") as requirements_file:
+        lines = [r.strip() for r in requirements_file.readlines()]
+        for line in lines:
+            is_extras = (
+                "flash-attn" in line
+                or "flash-attention" in line
+                or "deepspeed" in line
+                or "mamba-ssm" in line
+                or "lion-pytorch" in line
+            )
+            if line.startswith("--extra-index-url"):
+                # Handle custom index URLs
+                _, url = line.split()
+                _dependency_links.append(url)
+            elif not is_extras and line and line[0] != "#":
+                # Handle standard packages
+                _install_requires.append(line)
+
+    try:
+        xformers_version = [req for req in _install_requires if "xformers" in req][0]
+        torchao_version = [req for req in _install_requires if "torchao" in req][0]
+        autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
+
+        if "Darwin" in platform.system():
+            # don't install xformers on MacOS
+            _install_requires.pop(_install_requires.index(xformers_version))
+        else:
+            # detect the version of torch already installed
+            # and set it so dependencies don't clobber the torch version
+            try:
+                torch_version = version("torch")
+            except PackageNotFoundError:
+                torch_version = "2.5.1"
+            _install_requires.append(f"torch=={torch_version}")
+
+            version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
+            if version_match:
+                major, minor, patch = version_match.groups()
+                major, minor = int(major), int(minor)
+                patch = (
+                    int(patch) if patch is not None else 0
+                )  # Default patch to 0 if not present
+            else:
+                raise ValueError("Invalid version format")
+
+            if (major, minor) >= (2, 5):
+                _install_requires.pop(_install_requires.index(xformers_version))
+                if patch == 0:
+                    _install_requires.append("xformers==0.0.28.post2")
+                else:
+                    _install_requires.append("xformers==0.0.28.post3")
+                _install_requires.pop(_install_requires.index(autoawq_version))
+            elif (major, minor) >= (2, 4):
+                if patch == 0:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers>=0.0.27")
+                else:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers==0.0.28.post1")
+            elif (major, minor) >= (2, 3):
+                _install_requires.pop(_install_requires.index(torchao_version))
+                if patch == 0:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers>=0.0.26.post1")
+                else:
+                    _install_requires.pop(_install_requires.index(xformers_version))
+                    _install_requires.append("xformers>=0.0.27")
+            elif (major, minor) >= (2, 2):
+                _install_requires.pop(_install_requires.index(torchao_version))
+                _install_requires.pop(_install_requires.index(xformers_version))
+                _install_requires.append("xformers>=0.0.25.post1")
+            else:
+                _install_requires.pop(_install_requires.index(torchao_version))
+                _install_requires.pop(_install_requires.index(xformers_version))
+                _install_requires.append("xformers>=0.0.23.post1")
+
+    except PackageNotFoundError:
+        pass
+    return _install_requires, _dependency_links
+
+
+class BuildPyCommand(_build_py):
+    """
+    custom build_py command to parse dynamic requirements
+    """
+
+    def finalize_options(self):
+        super().finalize_options()
+        install_requires, _ = parse_requirements()
+        self.distribution.install_requires = install_requires
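The new module only defines the command class; a sketch of the setup.py wiring it implies (assumed, the actual registration is not shown in this compare view):

# Assumed setup.py wiring: register the custom command so install_requires
# is computed from requirements.txt at build time. Import path assumes the
# src/ directory is on sys.path during the build.
from setuptools import setup

from setuptools_axolotl_dynamic_dependencies import BuildPyCommand

setup(
    cmdclass={"build_py": BuildPyCommand},
)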
0 tests/cli/__init__.py Normal file
36 tests/cli/conftest.py Normal file
@@ -0,0 +1,36 @@
+"""Shared pytest fixtures for cli module."""
+import pytest
+from click.testing import CliRunner
+
+VALID_TEST_CONFIG = """
+base_model: HuggingFaceTB/SmolLM2-135M
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+sequence_len: 2048
+max_steps: 1
+micro_batch_size: 1
+gradient_accumulation_steps: 1
+learning_rate: 1e-3
+special_tokens:
+  pad_token: <|endoftext|>
+"""
+
+
+@pytest.fixture
+def cli_runner():
+    return CliRunner()
+
+
+@pytest.fixture
+def valid_test_config():
+    return VALID_TEST_CONFIG
+
+
+@pytest.fixture
+def config_path(tmp_path):
+    """Creates a temporary config file"""
+    path = tmp_path / "config.yml"
+    path.write_text(VALID_TEST_CONFIG)
+
+    return path
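Taken together, these fixtures keep each CLI test to a few lines; a hypothetical test composing them (not part of the diff):

# Hypothetical example of composing the fixtures above.
from unittest.mock import patch

from axolotl.cli.main import cli


def test_example_uses_fixtures(cli_runner, config_path):
    # cli_runner and config_path come from the conftest fixtures
    with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
        result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
        assert result.exit_code == 0
        mock_do_cli.assert_called_once()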
38 tests/cli/test_cli_fetch.py Normal file
@@ -0,0 +1,38 @@
+"""pytest tests for axolotl CLI fetch command."""
+from unittest.mock import patch
+
+from axolotl.cli.main import fetch
+
+
+def test_fetch_cli_examples(cli_runner):
+    """Test fetch command with examples directory"""
+    with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
+        result = cli_runner.invoke(fetch, ["examples"])
+
+        assert result.exit_code == 0
+        mock_fetch.assert_called_once_with("examples/", None)
+
+
+def test_fetch_cli_deepspeed(cli_runner):
+    """Test fetch command with deepspeed_configs directory"""
+    with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
+        result = cli_runner.invoke(fetch, ["deepspeed_configs"])
+
+        assert result.exit_code == 0
+        mock_fetch.assert_called_once_with("deepspeed_configs/", None)
+
+
+def test_fetch_cli_with_dest(cli_runner, tmp_path):
+    """Test fetch command with custom destination"""
+    with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
+        custom_dir = tmp_path / "tmp_examples"
+        result = cli_runner.invoke(fetch, ["examples", "--dest", str(custom_dir)])
+
+        assert result.exit_code == 0
+        mock_fetch.assert_called_once_with("examples/", str(custom_dir))
+
+
+def test_fetch_cli_invalid_directory(cli_runner):
+    """Test fetch command with invalid directory choice"""
+    result = cli_runner.invoke(fetch, ["invalid"])
+    assert result.exit_code != 0
30 tests/cli/test_cli_inference.py Normal file
@@ -0,0 +1,30 @@
+"""pytest tests for axolotl CLI inference command."""
+from unittest.mock import patch
+
+from axolotl.cli.main import cli
+
+
+def test_inference_basic(cli_runner, config_path):
+    """Test basic inference"""
+    with patch("axolotl.cli.inference.do_inference") as mock:
+        result = cli_runner.invoke(
+            cli,
+            ["inference", str(config_path), "--no-accelerate"],
+            catch_exceptions=False,
+        )
+
+        assert mock.called
+        assert result.exit_code == 0
+
+
+def test_inference_gradio(cli_runner, config_path):
+    """Test basic inference (gradio path)"""
+    with patch("axolotl.cli.inference.do_inference_gradio") as mock:
+        result = cli_runner.invoke(
+            cli,
+            ["inference", str(config_path), "--no-accelerate", "--gradio"],
+            catch_exceptions=False,
+        )
+
+        assert mock.called
+        assert result.exit_code == 0
47 tests/cli/test_cli_interface.py Normal file
@@ -0,0 +1,47 @@
+"""General pytest tests for axolotl.cli.main interface."""
+from axolotl.cli.main import build_command, cli
+
+
+def test_build_command():
+    """Test converting dict of options to CLI arguments"""
+    base_cmd = ["accelerate", "launch"]
+    options = {
+        "learning_rate": 1e-4,
+        "batch_size": 8,
+        "debug": True,
+        "use_fp16": False,
+        "null_value": None,
+    }
+
+    result = build_command(base_cmd, options)
+    assert result == [
+        "accelerate",
+        "launch",
+        "--learning-rate",
+        "0.0001",
+        "--batch-size",
+        "8",
+        "--debug",
+    ]
+
+
+def test_invalid_command_options(cli_runner):
+    """Test handling of invalid command options"""
+    result = cli_runner.invoke(
+        cli,
+        [
+            "train",
+            "config.yml",
+            "--invalid-option",
+            "value",
+        ],
+    )
+    assert result.exit_code != 0
+    assert "No such option" in result.output
+
+
+def test_required_config_argument(cli_runner):
+    """Test commands fail properly when config argument is missing"""
+    result = cli_runner.invoke(cli, ["train"])
+    assert result.exit_code != 0
+    assert "Missing argument 'CONFIG'" in result.output
56 tests/cli/test_cli_merge_lora.py Normal file
@@ -0,0 +1,56 @@
+"""pytest tests for axolotl CLI merge_lora command."""
+from unittest.mock import patch
+
+from axolotl.cli.main import cli
+
+
+def test_merge_lora_basic(cli_runner, config_path):
+    """Test basic merge_lora command"""
+    with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
+        result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+
+
+def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):
+    """Test merge_lora with custom lora and output directories"""
+    lora_dir = tmp_path / "lora"
+    output_dir = tmp_path / "output"
+    lora_dir.mkdir()
+
+    with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-lora",
+                str(config_path),
+                "--lora-model-dir",
+                str(lora_dir),
+                "--output-dir",
+                str(output_dir),
+            ],
+        )
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+        assert mock_do_cli.call_args.kwargs["lora_model_dir"] == str(lora_dir)
+        assert mock_do_cli.call_args.kwargs["output_dir"] == str(output_dir)
+
+
+def test_merge_lora_nonexistent_config(cli_runner, tmp_path):
+    """Test merge_lora with nonexistent config"""
+    config_path = tmp_path / "nonexistent.yml"
+    result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
+    assert result.exit_code != 0
+
+
+def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):
+    """Test merge_lora with nonexistent lora directory"""
+    lora_dir = tmp_path / "nonexistent"
+    result = cli_runner.invoke(
+        cli, ["merge-lora", str(config_path), "--lora-model-dir", str(lora_dir)]
+    )
+    assert result.exit_code != 0
60 tests/cli/test_cli_merge_sharded_fsdp_weights.py Normal file
@@ -0,0 +1,60 @@
+"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
+# pylint: disable=duplicate-code
+from unittest.mock import patch
+
+from axolotl.cli.main import cli
+
+
+def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
+    """Test merge_sharded_fsdp_weights command without accelerate"""
+    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
+        )
+
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert result.exit_code == 0
+
+
+def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
+    """Test merge_sharded_fsdp_weights command with model_dir option"""
+    model_dir = tmp_path / "model"
+    model_dir.mkdir()
+
+    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--no-accelerate",
+                "--model-dir",
+                str(model_dir),
+            ],
+        )
+
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["model_dir"] == str(model_dir)
+        assert result.exit_code == 0
+
+
+def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
+    """Test merge_sharded_fsdp_weights command with save_path option"""
+    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--no-accelerate",
+                "--save-path",
+                "/path/to/save",
+            ],
+        )
+
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["save_path"] == "/path/to/save"
+        assert result.exit_code == 0
71 tests/cli/test_cli_preprocess.py Normal file
@@ -0,0 +1,71 @@
+"""pytest tests for axolotl CLI preprocess command."""
+import shutil
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from axolotl.cli.main import cli
+
+
+@pytest.fixture(autouse=True)
+def cleanup_last_run_prepared():
+    yield
+
+    if Path("last_run_prepared").exists():
+        shutil.rmtree("last_run_prepared")
+
+
+def test_preprocess_config_not_found(cli_runner):
+    """Test preprocess fails when config not found"""
+    result = cli_runner.invoke(cli, ["preprocess", "nonexistent.yml"])
+    assert result.exit_code != 0
+
+
+def test_preprocess_basic(cli_runner, config_path):
+    """Test basic preprocessing with minimal config"""
+    with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
+        result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+        assert mock_do_cli.call_args.kwargs["download"] is True
+
+
+def test_preprocess_without_download(cli_runner, config_path):
+    """Test preprocessing without model download"""
+    with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
+        result = cli_runner.invoke(
+            cli, ["preprocess", str(config_path), "--no-download"]
+        )
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+        assert mock_do_cli.call_args.kwargs["download"] is False
+
+
+def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
+    """Test preprocessing with custom dataset path"""
+    config_path = tmp_path / "config.yml"
+    custom_path = tmp_path / "custom_prepared"
+    config_path.write_text(valid_test_config)
+
+    with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "preprocess",
+                str(config_path),
+                "--dataset-prepared-path",
+                str(custom_path.absolute()),
+            ],
+        )
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+        assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
+            custom_path.absolute()
+        )
76 tests/cli/test_cli_shard.py Normal file
@@ -0,0 +1,76 @@
+"""pytest tests for axolotl CLI shard command."""
+# pylint: disable=duplicate-code
+from unittest.mock import patch
+
+from axolotl.cli.main import cli
+
+
+def test_shard_with_accelerate(cli_runner, config_path):
+    """Test shard command with accelerate"""
+    with patch("subprocess.run") as mock:
+        result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
+
+        assert mock.called
+        assert mock.call_args.args[0] == [
+            "accelerate",
+            "launch",
+            "-m",
+            "axolotl.cli.shard",
+            str(config_path),
+            "--debug-num-examples",
+            "0",
+        ]
+        assert mock.call_args.kwargs == {"check": True}
+        assert result.exit_code == 0
+
+
+def test_shard_no_accelerate(cli_runner, config_path):
+    """Test shard command without accelerate"""
+    with patch("axolotl.cli.shard.do_cli") as mock:
+        result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
+
+        assert mock.called
+        assert result.exit_code == 0
+
+
+def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
+    """Test shard command with model_dir option"""
+    model_dir = tmp_path / "model"
+    model_dir.mkdir()
+
+    with patch("axolotl.cli.shard.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "shard",
+                str(config_path),
+                "--no-accelerate",
+                "--model-dir",
+                str(model_dir),
+            ],
+            catch_exceptions=False,
+        )
+
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["model_dir"] == str(model_dir)
+        assert result.exit_code == 0
+
+
+def test_shard_with_save_dir(cli_runner, config_path):
+    with patch("axolotl.cli.shard.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "shard",
+                str(config_path),
+                "--no-accelerate",
+                "--save-dir",
+                "/path/to/save",
+            ],
+        )
+
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
+        assert result.exit_code == 0
98 tests/cli/test_cli_train.py Normal file
@@ -0,0 +1,98 @@
+"""pytest tests for axolotl CLI train command."""
+from unittest.mock import MagicMock, patch
+
+from axolotl.cli.main import cli
+
+
+def test_train_cli_validation(cli_runner):
+    """Test CLI validation"""
+    # Test missing config file
+    result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
+    assert result.exit_code != 0
+
+    # Test non-existent config file
+    result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
+    assert result.exit_code != 0
+    assert "Error: Invalid value for 'CONFIG'" in result.output
+
+
+def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
+    """Test basic successful execution"""
+    config_path = tmp_path / "config.yml"
+    config_path.write_text(valid_test_config)
+
+    with patch("subprocess.run") as mock:
+        result = cli_runner.invoke(cli, ["train", str(config_path)])
+
+        assert mock.called
+        assert mock.call_args.args[0] == [
+            "accelerate",
+            "launch",
+            "-m",
+            "axolotl.cli.train",
+            str(config_path),
+            "--debug-num-examples",
+            "0",
+        ]
+        assert mock.call_args.kwargs == {"check": True}
+        assert result.exit_code == 0
+
+
+def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
+    """Test basic successful execution"""
+    config_path = tmp_path / "config.yml"
+    config_path.write_text(valid_test_config)
+
+    with patch("axolotl.cli.train.train") as mock_train:
+        mock_train.return_value = (MagicMock(), MagicMock())
+
+        result = cli_runner.invoke(
+            cli,
+            [
+                "train",
+                str(config_path),
+                "--learning-rate",
+                "1e-4",
+                "--micro-batch-size",
+                "2",
+                "--no-accelerate",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_train.assert_called_once()
+
+
+def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
+    """Test CLI arguments properly override config values"""
+    config_path = tmp_path / "config.yml"
+    output_dir = tmp_path / "model-out"
+
+    test_config = valid_test_config.replace(
+        "output_dir: model-out", f"output_dir: {output_dir}"
+    )
+    config_path.write_text(test_config)
+
+    with patch("axolotl.cli.train.train") as mock_train:
+        mock_train.return_value = (MagicMock(), MagicMock())
+
+        result = cli_runner.invoke(
+            cli,
+            [
+                "train",
+                str(config_path),
+                "--learning-rate",
+                "1e-4",
+                "--micro-batch-size",
+                "2",
+                "--no-accelerate",
+            ],
+            catch_exceptions=False,
+        )
+
+        assert result.exit_code == 0
+        mock_train.assert_called_once()
+        cfg = mock_train.call_args[1]["cfg"]
+        assert cfg["learning_rate"] == 1e-4
+        assert cfg["micro_batch_size"] == 2
10 tests/cli/test_cli_version.py Normal file
@@ -0,0 +1,10 @@
+"""pytest tests for axolotl CLI --version"""
+from axolotl.cli.main import cli
+
+
+def test_print_version(cli_runner):
+    """Test that version is printed when --version is used."""
+
+    result = cli_runner.invoke(cli, ["--version"])
+    assert result.exit_code == 0
+    assert "axolotl, version " in result.output
72 tests/cli/test_utils.py Normal file
@@ -0,0 +1,72 @@
+"""pytest tests for axolotl CLI utils."""
+# pylint: disable=redefined-outer-name
+import json
+from unittest.mock import Mock, patch
+
+import click
+import pytest
+import requests
+
+from axolotl.cli.utils import fetch_from_github
+
+# Sample GitHub API response
+MOCK_TREE_RESPONSE = {
+    "tree": [
+        {"path": "examples/config1.yml", "type": "blob", "sha": "abc123"},
+        {"path": "examples/config2.yml", "type": "blob", "sha": "def456"},
+        {"path": "other/file.txt", "type": "blob", "sha": "xyz789"},
+    ]
+}
+
+
+@pytest.fixture
+def mock_responses():
+    """Mock responses for API and file downloads"""
+
+    def mock_get(url, timeout=None):  # pylint: disable=unused-argument
+        response = Mock()
+        if "api.github.com" in url:
+            response.text = json.dumps(MOCK_TREE_RESPONSE)
+        else:
+            response.content = b"file content"
+        return response
+
+    return mock_get
+
+
+def test_fetch_from_github_new_files(tmp_path, mock_responses):
+    """Test fetching new files"""
+    with patch("requests.get", mock_responses):
+        fetch_from_github("examples/", tmp_path)
+
+    # Verify files were created
+    assert (tmp_path / "config1.yml").exists()
+    assert (tmp_path / "config2.yml").exists()
+    assert not (tmp_path / "file.txt").exists()
+
+
+def test_fetch_from_github_unchanged_files(tmp_path, mock_responses):
+    """Test handling of unchanged files"""
+    # Create existing file with matching SHA
+    existing_file = tmp_path / "config1.yml"
+    existing_file.write_bytes(b"file content")
+
+    with patch("requests.get", mock_responses):
+        fetch_from_github("examples/", tmp_path)
+
+    # File should not be downloaded again
+    assert existing_file.read_bytes() == b"file content"
+
+
+def test_fetch_from_github_invalid_prefix(mock_responses):
+    """Test error handling for invalid directory prefix"""
+    with patch("requests.get", mock_responses):
+        with pytest.raises(click.ClickException):
+            fetch_from_github("nonexistent/", None)
+
+
+def test_fetch_from_github_network_error():
+    """Test handling of network errors"""
+    with patch("requests.get", side_effect=requests.RequestException):
+        with pytest.raises(requests.RequestException):
+            fetch_from_github("examples/", None)
@@ -1,68 +1,109 @@
 """
 shared pytest fixtures
 """
+import functools
+import importlib
 import shutil
+import sys
 import tempfile
+import time
 
 import pytest
+import requests
 from huggingface_hub import snapshot_download
 
 
+def retry_on_request_exceptions(max_retries=3, delay=1):
+    # pylint: disable=duplicate-code
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except (
+                    requests.exceptions.ReadTimeout,
+                    requests.exceptions.ConnectionError,
+                ) as exc:
+                    if attempt < max_retries - 1:
+                        time.sleep(delay)
+                    else:
+                        raise exc
+
+        return wrapper
+
+    return decorator
+
+
+@retry_on_request_exceptions(max_retries=3, delay=5)
+def snapshot_download_w_retry(*args, **kwargs):
+    return snapshot_download(*args, **kwargs)
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_smollm2_135m_model():
     # download the model
-    snapshot_download("HuggingFaceTB/SmolLM2-135M")
+    snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
 
 
 @pytest.fixture(scope="session", autouse=True)
 def download_llama_68m_random_model():
     # download the model
-    snapshot_download("JackFram/llama-68m")
+    snapshot_download_w_retry("JackFram/llama-68m")
 
 
 @pytest.fixture(scope="session", autouse=True)
 def download_qwen_2_5_half_billion_model():
     # download the model
-    snapshot_download("Qwen/Qwen2.5-0.5B")
+    snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
 
 
 @pytest.fixture(scope="session", autouse=True)
 def download_tatsu_lab_alpaca_dataset():
     # download the dataset
-    snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
+    snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
 
 
 @pytest.fixture(scope="session", autouse=True)
 def download_mhenrichsen_alpaca_2k_dataset():
     # download the dataset
-    snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
+    snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
 
 
 @pytest.fixture(scope="session", autouse=True)
 def download_mhenrichsen_alpaca_2k_w_revision_dataset():
     # download the dataset
-    snapshot_download(
+    snapshot_download_w_retry(
         "mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
     )
 
 
+@pytest.fixture(scope="session", autouse=True)
 def download_mlabonne_finetome_100k_dataset():
     # download the dataset
-    snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
+    snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
 
 
-@pytest.fixture
+@pytest.fixture(scope="session", autouse=True)
 def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
     # download the dataset
-    snapshot_download(
+    snapshot_download_w_retry(
        "argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
     )
 
 
-@pytest.fixture
+@pytest.fixture(scope="session", autouse=True)
+def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
+    # download the dataset
+    snapshot_download_w_retry(
+        "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
+    )
+
+
+@pytest.fixture(scope="session", autouse=True)
 def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
     # download the dataset
-    snapshot_download(
+    snapshot_download_w_retry(
        "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
     )
 
@@ -74,3 +115,57 @@ def temp_dir():
     yield _temp_dir
     # Clean up the directory after the test
     shutil.rmtree(_temp_dir)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def cleanup_monkeypatches():
+    from transformers import Trainer
+    from transformers.models.llama.modeling_llama import (
+        LlamaAttention,
+        LlamaFlashAttention2,
+        LlamaForCausalLM,
+    )
+
+    original_fa2_forward = LlamaFlashAttention2.forward
+    original_llama_attn_forward = LlamaAttention.forward
+    original_llama_forward = LlamaForCausalLM.forward
+    original_trainer_inner_training_loop = (
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    original_trainer_training_step = Trainer.training_step
+    # monkey patches can happen inside the tests
+    yield
+    # Reset LlamaFlashAttention2 forward
+    LlamaFlashAttention2.forward = original_fa2_forward
+    LlamaAttention.forward = original_llama_attn_forward
+    LlamaForCausalLM.forward = original_llama_forward
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        original_trainer_inner_training_loop
+    )
+    Trainer.training_step = original_trainer_training_step
+
+    # Reset other known monkeypatches
+    modules_to_reset: list[tuple[str, list[str]]] = [
+        ("transformers.models.llama",),
+        (
+            "transformers.models.llama.modeling_llama",
+            ["LlamaFlashAttention2", "LlamaAttention"],
+        ),
+        ("transformers.trainer",),
+        ("transformers", ["Trainer"]),
+        ("transformers.loss.loss_utils",),
+    ]
+    for module_name_tuple in modules_to_reset:
+        module_name = module_name_tuple[0]
+
+        spec = importlib.util.spec_from_file_location(
+            module_name, sys.modules[module_name].__file__
+        )
+        sys.modules[module_name] = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(sys.modules[module_name])
+
+        sys.modules[module_name] = importlib.reload(sys.modules[module_name])
+        if len(module_name_tuple) > 1:
+            module_globals = module_name_tuple[1]
+            for module_global in module_globals:
+                globals().pop(module_global, None)
@@ -71,7 +71,11 @@ class TestCutCrossEntropyIntegration:
 
     @pytest.mark.parametrize(
         "attention_type",
-        ["flash_attention", "sdp_attention", "xformers_attention"],
+        [
+            "flash_attention",
+            "sdp_attention",
+            # "xformers_attention",
+        ],
     )
     def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
        cfg = DictDefault(
@@ -7,12 +7,11 @@ from pathlib import Path
 
 import yaml
 from accelerate.test_utils import execute_subprocess_async
-from tbparse import SummaryReader
 from transformers.testing_utils import get_torch_dist_unique_port
 
 from axolotl.utils.dict import DictDefault
 
-from ..utils import most_recent_subdir
+from ..utils import check_tensorboard
 
 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"
@@ -91,12 +90,8 @@ class TestMultiGPUEval:
                 str(Path(temp_dir) / "config.yaml"),
             ]
         )
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "eval/loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.5, "Loss is too high"
+        check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
 
     def test_eval(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -164,9 +159,5 @@ class TestMultiGPUEval:
                 str(Path(temp_dir) / "config.yaml"),
             ]
         )
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "eval/loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.9, "Loss is too high"
+        check_tensorboard(temp_dir + "/runs", "eval/loss", 2.9, "Eval Loss is too high")
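The `check_tensorboard` helper replaces the inlined tbparse logic removed above; the following is a sketch of what it plausibly does, reconstructed from that removed code (the real implementation lives in tests/e2e/utils and may differ in detail):

# Sketch reconstructed from the removed inline logic; the helper's actual
# location (tests/e2e/utils) and most-recent-subdir selection are assumptions.
import os

from tbparse import SummaryReader


def check_tensorboard(tb_root: str, tag: str, lt_val: float, assertion_err: str):
    # pick the most recently modified run directory under the tensorboard root
    tb_log_path = max(
        (os.path.join(tb_root, d) for d in os.listdir(tb_root)),
        key=os.path.getmtime,
    )
    event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
    df = SummaryReader(event_file).scalars
    df = df[(df.tag == tag)]
    assert df.value.values[-1] < lt_val, assertion_err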
@@ -9,13 +9,12 @@ from pathlib import Path
 import pytest
 import yaml
 from accelerate.test_utils import execute_subprocess_async
+from e2e.utils import check_tensorboard
 from huggingface_hub import snapshot_download
 from transformers.testing_utils import get_torch_dist_unique_port
 
 from axolotl.utils.dict import DictDefault
 
-from ..utils import is_hopper
-
 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"
 
@@ -55,7 +54,7 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 2,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -63,6 +62,7 @@ class TestMultiGPULlama:
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
+                "use_tensorboard": True,
             }
         )
 
@@ -85,9 +85,13 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
-        [1, 4],
+        [1, 2],
     )
     def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
@@ -114,14 +118,15 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
-                "micro_batch_size": 4,
+                "max_steps": 2,
+                "micro_batch_size": 1,
                 "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
+                "use_tensorboard": True,
             }
         )
 
@@ -144,7 +149,10 @@ class TestMultiGPULlama:
             ]
         )
 
-    @pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_dpo_lora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -183,7 +191,7 @@ class TestMultiGPULlama:
                    },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 2,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -192,6 +200,7 @@ class TestMultiGPULlama:
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
+                "use_tensorboard": True,
             }
         )
 
@@ -214,6 +223,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_dpo_qlora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -252,8 +265,8 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
-                "micro_batch_size": 4,
+                "max_steps": 2,
+                "micro_batch_size": 2,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
                 "warmup_steps": 0,
@@ -261,6 +274,7 @@ class TestMultiGPULlama:
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
+                "use_tensorboard": True,
             }
         )
 
@@ -283,9 +297,13 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
-        [1, 4],
+        [1, 2],
     )
     def test_fsdp(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
@@ -304,8 +322,8 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 10,
-                "micro_batch_size": 4,
+                "max_steps": 2,
+                "micro_batch_size": 2,
                 "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
@@ -326,6 +344,7 @@ class TestMultiGPULlama:
                     "fsdp_state_dict_type": "FULL_STATE_DICT",
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
+                "use_tensorboard": True,
             }
         )
 
@@ -348,6 +367,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "fsdp_state_dict_type",
         ["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
@@ -371,7 +394,7 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 2,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -393,6 +416,7 @@ class TestMultiGPULlama:
                     "fsdp_state_dict_type": fsdp_state_dict_type,
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
+                "use_tensorboard": True,
             }
         )
 
@@ -415,6 +439,10 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     def test_fsdp_qlora_prequant_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -447,7 +475,7 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
+                "max_steps": 2,
                 "micro_batch_size": 4,
                 "gradient_accumulation_steps": 4,
                 "output_dir": temp_dir,
@@ -469,6 +497,7 @@ class TestMultiGPULlama:
                     "fsdp_state_dict_type": "SHARDED_STATE_DICT",
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
+                "use_tensorboard": True,
             }
         )
 
@@ -491,12 +520,41 @@ class TestMultiGPULlama:
             ]
         )
 
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
-        [1, 4],
+        [1, 2],
     )
-    def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
+    @pytest.mark.parametrize(
+        "deepspeed",
+        [
+            "deepspeed_configs/zero3_bf16.json",
+            "deepspeed_configs/zero3_bf16_cpuoffload_all.json",
+            # "deepspeed_configs/zero3_bf16_cpuoffload_params.json",
+        ],
+    )
+    @pytest.mark.parametrize(
+        "qlora",
+        [True, False],
+    )
+    def test_ds_zero3_packed(
+        self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
+    ):
         # pylint: disable=duplicate-code
+        if qlora:
+            adapter = {
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "load_in_4bit": True,
+            }
+        else:
+            adapter = {}
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -514,15 +572,17 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
-                "micro_batch_size": 4,
+                "max_steps": 2,
+                "micro_batch_size": 1,
                 "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
-                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+                "deepspeed": str(AXOLOTL_ROOT / deepspeed),
+                "use_tensorboard": True,
+                **adapter,
             }
         )
 
@@ -545,19 +605,35 @@ class TestMultiGPULlama:
             ]
         )
 
-    def test_ds_zero3_qlora_packed(self, temp_dir):
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 2],
+    )
+    @pytest.mark.parametrize(
+        "qlora",
+        [True, False],
+    )
+    def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
         # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "load_in_4bit": True,
+        if qlora:
+            adapter = {
                 "adapter": "qlora",
                 "lora_r": 8,
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
+                "load_in_4bit": True,
+            }
+        else:
+            adapter = {}
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sample_packing": True,
-                "eval_sample_packing": False,
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "val_set_size": 0.05,
@@ -571,15 +647,17 @@ class TestMultiGPULlama:
                     },
                 ],
                 "num_epochs": 1,
-                "max_steps": 15,
-                "micro_batch_size": 4,
-                "gradient_accumulation_steps": 4,
+                "max_steps": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
-                "learning_rate": 0.0001,
+                "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
-                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
+                "use_tensorboard": True,
+                **adapter,
             }
         )
 
@@ -601,3 +679,82 @@ class TestMultiGPULlama:
                 str(Path(temp_dir) / "config.yaml"),
             ]
         )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
+
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 2],
+    )
+    @pytest.mark.parametrize(
+        "qlora",
+        [True, False],
+    )
+    def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
+        # pylint: disable=duplicate-code
+        if qlora:
+            adapter = {
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "load_in_4bit": True,
+            }
+        else:
+            adapter = {}
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
+                "use_tensorboard": True,
+                **adapter,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+        )
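The **adapter unpacking used by these tests is what toggles each run between a full finetune and QLoRA without duplicating the rest of the config. A minimal sketch of the pattern:

# The qlora branch contributes extra config keys via dict unpacking;
# the non-qlora branch contributes nothing, leaving the base config as-is.
qlora = True
adapter = {"adapter": "qlora", "load_in_4bit": True} if qlora else {}
cfg = {"base_model": "HuggingFaceTB/SmolLM2-135M", **adapter}
assert ("adapter" in cfg) == qlora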

@@ -42,7 +42,7 @@ class Test4dMultipackLlama(unittest.TestCase):
             "lora_dropout": 0.05,
             "lora_target_linear": True,
             "sequence_len": 1024,
-            "val_set_size": 0.1,
+            "val_set_size": 0.02,
             "datasets": [
                 {
                     "path": "mhenrichsen/alpaca_2k_test",
@@ -86,7 +86,7 @@ class Test4dMultipackLlama(unittest.TestCase):
             "lora_alpha": 16,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
-            "val_set_size": 0.1,
+            "val_set_size": 0.02,
             "datasets": [
                 {
                     "path": "mhenrichsen/alpaca_2k_test",

tests/e2e/patched/test_cli_integrations.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+"""
+test cases to make sure the plugin args are loaded from the config file
+"""
+from pathlib import Path
+
+import yaml
+
+from axolotl.cli import load_cfg
+from axolotl.utils.dict import DictDefault
+
+
+# pylint: disable=duplicate-code
+class TestPluginArgs:
+    """
+    test class for plugin args loaded from the config file
+    """
+
+    def test_liger_plugin_args(self, temp_dir):
+        test_cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "learning_rate": 0.000001,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "plugins": ["axolotl.integrations.liger.LigerPlugin"],
+                "liger_layer_norm": True,
+                "liger_rope": True,
+                "liger_rms_norm": False,
+                "liger_glu_activation": True,
+                "liger_fused_linear_cross_entropy": True,
+            }
+        )
+
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(test_cfg.to_dict()))
+        cfg = load_cfg(str(Path(temp_dir) / "config.yaml"))
+        assert cfg.liger_layer_norm is True
+        assert cfg.liger_rope is True
+        assert cfg.liger_rms_norm is False
+        assert cfg.liger_glu_activation is True
+        assert cfg.liger_fused_linear_cross_entropy is True
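The new test round-trips the liger flags through YAML and load_cfg to confirm plugin args survive config loading. For reference, a rough sketch (assumed output, not part of the diff) of what yaml.dump(test_cfg.to_dict()) emits for the plugin keys:

import yaml

cfg = {
    "plugins": ["axolotl.integrations.liger.LigerPlugin"],
    "liger_rope": True,
    "liger_rms_norm": False,
}
# yaml.dump sorts keys alphabetically by default, producing roughly:
#   liger_rms_norm: false
#   liger_rope: true
#   plugins:
#   - axolotl.integrations.liger.LigerPlugin
print(yaml.dump(cfg))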

@@ -4,11 +4,9 @@ E2E tests for lora llama
 
 import logging
 import os
-from importlib import reload
 from pathlib import Path
 
 import pytest
-from tbparse import SummaryReader
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -17,20 +15,12 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from ..utils import most_recent_subdir
+from ..utils import check_tensorboard
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-@pytest.fixture(autouse=True)
-def reload_transformers():
-    import transformers.models.llama.modeling_llama
-
-    yield
-    reload(transformers.models.llama.modeling_llama)
-
-
 class TestFAXentropyLlama:
     """
     Test case for Llama models using LoRA w multipack
@@ -94,9 +84,6 @@ class TestFAXentropyLlama:
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 1.5, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
+        )

@@ -40,7 +40,7 @@ class TestFalconPatched(unittest.TestCase):
             "lora_dropout": 0.1,
             "lora_target_linear": True,
             "lora_modules_to_save": ["word_embeddings", "lm_head"],
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {
                 "bos_token": "<|endoftext|>",
                 "pad_token": "<|endoftext|>",
@@ -80,7 +80,7 @@ class TestFalconPatched(unittest.TestCase):
             "flash_attention": True,
             "sample_packing": True,
             "sequence_len": 2048,
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {
                 "bos_token": "<|endoftext|>",
                 "pad_token": "<|endoftext|>",

@@ -7,6 +7,7 @@ import os
 import unittest
 from pathlib import Path
 
+import pytest
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -21,6 +22,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
+@pytest.mark.skip("FIXME, mostly underused functionality")
 class TestFusedLlama(unittest.TestCase):
     """
     Test case for Llama models using Fused layers
@@ -38,7 +40,7 @@ class TestFusedLlama(unittest.TestCase):
             "flash_attn_fuse_mlp": True,
             "sample_packing": True,
             "sequence_len": 1024,
-            "val_set_size": 0.1,
+            "val_set_size": 0.02,
             "special_tokens": {
                 "unk_token": "<unk>",
                 "bos_token": "<s>",

@@ -98,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
             "lora_alpha": 64,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
-            "val_set_size": 0.1,
+            "val_set_size": 0.02,
             "special_tokens": {
                 "unk_token": "<unk>",
                 "bos_token": "<s>",

@@ -39,7 +39,7 @@ class TestMistral(unittest.TestCase):
             "lora_alpha": 64,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {
                 "unk_token": "<unk>",
                 "bos_token": "<s>",
@@ -80,7 +80,7 @@ class TestMistral(unittest.TestCase):
             "flash_attention": True,
             "sample_packing": True,
             "sequence_len": 1024,
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {
                 "unk_token": "<unk>",
                 "bos_token": "<s>",

@@ -40,7 +40,7 @@ class TestMixtral(unittest.TestCase):
             "lora_alpha": 32,
             "lora_dropout": 0.1,
             "lora_target_linear": True,
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {},
             "datasets": [
                 {
@@ -78,7 +78,7 @@ class TestMixtral(unittest.TestCase):
             "flash_attention": True,
             "sample_packing": True,
             "sequence_len": 2048,
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {},
             "datasets": [
                 {

@@ -38,7 +38,7 @@ class TestPhiMultipack(unittest.TestCase):
             "pad_to_sequence_len": True,
             "load_in_8bit": False,
             "adapter": None,
-            "val_set_size": 0.1,
+            "val_set_size": 0.05,
             "special_tokens": {
                 "pad_token": "<|endoftext|>",
             },

@@ -6,8 +6,6 @@ import os
 from pathlib import Path
 
 import pytest
-from e2e.utils import most_recent_subdir
-from tbparse import SummaryReader
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -15,6 +13,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
+from ..utils import check_tensorboard
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -36,6 +36,9 @@ class TestUnslothQLoRA:
                 "sequence_len": 1024,
                 "sample_packing": sample_packing,
                 "flash_attention": True,
+                "unsloth_lora_mlp": True,
+                "unsloth_lora_qkv": True,
+                "unsloth_lora_o": True,
                 "load_in_4bit": True,
                 "adapter": "qlora",
                 "lora_r": 16,
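The three unsloth_lora_* flags added to these configs each opt a projection group into Unsloth's optimized LoRA path; this is my reading of the flag names, since the gating logic itself is not part of this diff:

# Hypothetical annotated fragment; flag names as used in the diff above.
unsloth_flags = {
    "unsloth_lora_mlp": True,  # MLP projections (gate/up/down)
    "unsloth_lora_qkv": True,  # attention query/key/value projections
    "unsloth_lora_o": True,  # attention output projection
}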
@@ -73,18 +76,18 @@ class TestUnslothQLoRA:
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.0, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )
 
     def test_unsloth_llama_qlora_unpacked(self, temp_dir):
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sequence_len": 1024,
+                "unsloth_lora_mlp": True,
+                "unsloth_lora_qkv": True,
+                "unsloth_lora_o": True,
                 "sample_packing": False,
                 "load_in_4bit": True,
                 "adapter": "qlora",
@@ -123,12 +126,9 @@ class TestUnslothQLoRA:
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.0, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )
 
     @pytest.mark.parametrize(
         "sdp_attention",
@@ -139,6 +139,9 @@ class TestUnslothQLoRA:
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sequence_len": 1024,
+                "unsloth_lora_mlp": True,
+                "unsloth_lora_qkv": True,
+                "unsloth_lora_o": True,
                 "sample_packing": False,
                 "load_in_4bit": True,
                 "adapter": "qlora",
@@ -178,9 +181,6 @@ class TestUnslothQLoRA:
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.0, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )

@@ -7,15 +7,13 @@ import os
 import unittest
 from pathlib import Path
 
-from tbparse import SummaryReader
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import most_recent_subdir, with_temp_dir
+from .utils import check_tensorboard, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -66,12 +64,9 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.0, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
+        )
 
     @with_temp_dir
     def test_train_w_embedding_lr(self, temp_dir):
@@ -113,9 +108,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.0, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
+        )

@@ -6,7 +6,6 @@ import logging
 import os
 import unittest
 
-from tbparse import SummaryReader
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.cli import load_datasets
@@ -15,7 +14,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import most_recent_subdir, with_temp_dir
+from .utils import check_tensorboard, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -66,9 +65,6 @@ class TestPackedLlama(unittest.TestCase):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/train_loss")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 2.0, "Loss is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )

@@ -7,15 +7,13 @@ import os
 import unittest
 from pathlib import Path
 
-from tbparse import SummaryReader
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import most_recent_subdir, with_temp_dir
+from .utils import check_tensorboard, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -85,9 +83,6 @@ class TestReLoraLlama(unittest.TestCase):
         ).exists()
         assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
 
-        tb_log_path = most_recent_subdir(temp_dir + "/runs")
-        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
-        reader = SummaryReader(event_file)
-        df = reader.scalars  # pylint: disable=invalid-name
-        df = df[(df.tag == "train/grad_norm")]  # pylint: disable=invalid-name
-        assert df.value.values[-1] < 0.2, "grad_norm is too high"
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
+        )

@@ -12,6 +12,7 @@ import torch
 
 # from importlib.metadata import version
 from packaging import version
+from tbparse import SummaryReader
 
 
 def with_temp_dir(test_func):
@@ -66,3 +67,17 @@ def require_torch_2_5_1(test_case):
 def is_hopper():
     compute_capability = torch.cuda.get_device_capability()
     return compute_capability == (9, 0)
+
+
+def check_tensorboard(
+    temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
+) -> None:
+    """
+    helper function to parse and check tensorboard logs
+    """
+    tb_log_path = most_recent_subdir(temp_run_dir)
+    event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
+    reader = SummaryReader(event_file)
+    df = reader.scalars  # pylint: disable=invalid-name
+    df = df[(df.tag == tag)]  # pylint: disable=invalid-name
+    assert df.value.values[-1] < lt_val, assertion_err
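check_tensorboard consolidates the tb-parsing boilerplate that the earlier hunks delete from the individual tests; it leans on most_recent_subdir and os, which utils.py already imports. Note that it asserts only on the last scalar logged under tag, not the whole curve. Typical call site, as used throughout this changeset:

check_tensorboard(
    temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)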

tests/patched/test_llama_trainer_ga.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+"""Test module for checking whether Hugging Face Transformers is working as expected."""
+import unittest
+
+from axolotl.monkeypatch.trainer_grad_accum import (
+    check_forward_is_patchable,
+    check_training_step_is_patchable,
+)
+
+
+class TestTrainerGAIntegration(unittest.TestCase):
+    """llama monkeypatch integration tests."""
+
+    def test_train_step_patchable(self):
+        # ensures the current version of transformers has loss code that matches our patching code
+        self.assertTrue(
+            check_training_step_is_patchable(),
+            "HF transformers Trainer.training_step has changed and isn't patchable",
+        )
+
+    def test_model_forward_patchable(self):
+        # ensures the current version of transformers has loss code that matches our patching code
+        self.assertTrue(
+            check_forward_is_patchable(),
+            "HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
+        )
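These patchability checks guard the gradient-accumulation monkeypatch against upstream drift in transformers. The diff does not show how the checks are implemented; a minimal sketch of one common approach, verifying that the upstream source still contains the fragment the patch expects to replace (everything except Trainer.training_step is an assumption here):

import inspect

from transformers import Trainer

# Stand-in for the exact upstream code the monkeypatch expects to find.
EXPECTED_FRAGMENT = "def training_step"


def check_training_step_is_patchable() -> bool:
    # If upstream rewrites training_step, the fragment disappears and the
    # test fails loudly instead of the patch silently misbehaving.
    source = inspect.getsource(Trainer.training_step)
    return EXPECTED_FRAGMENT in source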

@@ -4,6 +4,7 @@ shared fixtures for prompt strategies tests
 
 import pytest
 from datasets import Dataset
+from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
 
 
@@ -60,6 +61,17 @@ def fixture_basic_dataset():
 
 @pytest.fixture(name="llama3_tokenizer")
 def fixture_llama3_tokenizer():
+    hf_hub_download(
+        repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
+        filename="special_tokens_map.json",
+    )
+    hf_hub_download(
+        repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
+        filename="tokenizer_config.json",
+    )
+    hf_hub_download(
+        repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
+    )
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
 
     return tokenizer
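The fixture change pre-fetches the three tokenizer files individually with hf_hub_download, warming the shared Hugging Face cache before AutoTokenizer.from_pretrained resolves the repo (transformers and huggingface_hub use the same cache directory, so each file is fetched and cached independently). Condensed, the pattern is:

from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

REPO = "NousResearch/Meta-Llama-3-8B-Instruct"
# Fetch only the tokenizer files; the model weights are never touched.
for filename in ("special_tokens_map.json", "tokenizer_config.json", "tokenizer.json"):
    hf_hub_download(repo_id=REPO, filename=filename)

tokenizer = AutoTokenizer.from_pretrained(REPO)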

@@ -1,5 +1,5 @@
 """
-test module for the axolotl.utis.data module
+test module for the axolotl.utils.data module
 """
 import unittest
 