Compare commits

..

38 Commits

Author SHA1 Message Date
Dan Saunders
156fede4f7 Update .pre-commit-config.yaml
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 10:36:18 -04:00
Dan Saunders
dcbbd7af79 sorry to revert, but pylint complained 2025-03-21 10:36:18 -04:00
Dan Saunders
21bac7ce1a running updated pre-commit plugins 2025-03-21 10:36:18 -04:00
Dan Saunders
aaa4571826 adding pre-commit auto-update GH action and bumping plugin versions 2025-03-21 10:36:17 -04:00
salman
187227d837 Fixing KTO+QLoRA+multi-GPU (#2420)
* WIP

* removing artifacts

* adding error

* adding adapter check

* linting

* simplifying check

* linting v2

* config fix -___-
2025-03-21 10:18:28 -04:00
NanoCode012
f8de8bb4f2 chore(doc): add instructions on adding custom integrations (#2422) [skip ci]
* chore(doc): add instructions on adding custom integrations

* chore: add warning help

* feat: add note about integration path

* fix: adjust text per suggestion
2025-03-21 10:18:01 -04:00
hugo
8e604848a4 add run on novita ai (#2421) [skip ci]
* add run on novita ai

* Revert "add run on novita ai"

This reverts commit 4d5df1ac6b.

* add run axolotl on novita ai
2025-03-21 10:17:47 -04:00
Wing Lian
aae4337f40 add 12.8.1 cuda to the base matrix (#2426)
* add 12.8.1 cuda to the base matrix

* use nightly

* bump deepspeed and set no binary

* deepspeed binary fixes hopefully

* install deepspeed by itself

* multiline fix

* make sure ninja is installed

* try with reversion of packaging/setuptools/wheel install

* use license instead of license-file

* try rolling back packaging and setuptools versions

* comment out license for validation for now

* make sure packaging version is consistent

* more parity across tests and docker images for packaging/setuptools
2025-03-21 10:17:25 -04:00
Wing Lian
38df5a36ea bump HF versions except for trl (#2427) 2025-03-20 10:22:05 -04:00
Wing Lian
4d92a68a96 use default torch fused adamw optimizer as default as adamw_hf is deprecated (#2425)
* use default torch fused adamw optimizer as default as adamw_hf is deprecated

* make sure to have latest packaging installed

* bump packaging in requirements.txt too
2025-03-19 23:58:33 -04:00
SicariusSicariiStuff
85147ec430 Update README.md (#2360)
* Update README.md

wheel is needed

* feat: add ninja, setuptools, packaging to installation steps

* fix: add missing instruction

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-03-17 08:39:17 -04:00
NanoCode012
51cd409488 Feat: minor docs improvements for RLHF and faq on embeddings (#2401) [skip ci]
* feat: add doc on shrink_embeddings and custom calling

* chore: rename inference doc

* fix: clarify same config is used for all cli

* chore: rearrange order inference qmd

* feat: add simpo to doc

* fix: update defaults

* feat: add rl configs to doc

* fix: ensure beta consistent with trl.beta

* fix: clarify about lora/fft

* chore: rename title

* chore: fix language

* feat: move config reference higher

* Update docs/getting-started.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/rlhf.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-17 08:39:04 -04:00
NanoCode012
7235123d44 chore(docs): add cookbook/blog link to docs (#2410) [skip ci] 2025-03-17 08:38:19 -04:00
Wing Lian
4f5eb42a73 remove reference to deprecated import (#2407) 2025-03-15 08:49:41 -04:00
Wing Lian
fbe54be6b8 only validate hf user token on rank 0 (#2408) 2025-03-13 23:29:06 -04:00
Wing Lian
04f6324833 build cloud images with torch 2.6.0 (#2413)
* build cloud images with torch 2.6.0

* nightlies too
2025-03-13 23:28:51 -04:00
Wing Lian
f0072f3b9d use max of 32 dataset processes if not explicit (#2403)
* use max of 32 dataset processes if not explicit

* change alternate min val for consistency
2025-03-11 12:02:58 -04:00
Wing Lian
59899b9817 pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
2025-03-11 12:02:43 -04:00
NanoCode012
4a736986fa fix(modal): add git pull when getting branch files (#2399) 2025-03-10 15:14:41 -04:00
Wing Lian
5d0f110a3b include iproute2 and nvtop in cloud image (#2393) 2025-03-10 15:13:38 -04:00
NanoCode012
83f8698b8a fix: create mount folder on modal if not exist (#2390) 2025-03-10 16:27:42 +07:00
xzuyn
60a11a6410 Use Latest Cut Cross Entropy (#2392)
* Update __init__.py

* Update README.md

* Update cutcrossentropy_install.py

* add test
2025-03-10 16:26:40 +07:00
NanoCode012
46a045e528 chore(doc): add faq when having no default chat_template (#2398)
* chore(doc): add faq when having no default chat_template

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/faq.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:50 +07:00
NanoCode012
3b477e08a0 feat(doc): add more info on RewardModel datasets (#2391)
* fix: reduce title size

* feat(doc): add rm dataset info

* Update docs/reward_modelling.qmd following suggestion

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:31 +07:00
NanoCode012
16dc6ee68d refactor: trl grpo configs to have descriptions (#2386)
* refactor: trl grpo configs to have descriptions

* chore: caps
2025-03-07 08:58:53 -05:00
Wing Lian
fa7c79b3b9 remove lion-pytorch as it's already handled upstream (#2389) 2025-03-07 08:58:15 -05:00
Wing Lian
ae66374156 Optimizer refactor and add Muon support (#2367)
* add muon optimizer

optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer

* fix import
2025-03-06 11:49:19 -05:00
Wing Lian
5e21b1a9da various fixes 20250305 (#2384)
* various validation fixes

* fix check for non-truthy value
2025-03-06 11:48:44 -05:00
mhenrichsen
575e5f28ec Update Tokenizer Overrides Handling in models.py (#1549)
* override special tokens mock code

* fix(doc): remove duplicate config

* feat: replace added_tokens in tokenizer and add test

* make sure to run tokenizer modification on rank 0 only

* use is local main process instead

* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-05 11:15:12 -05:00
xzuyn
0134093acc Add REX LR Scheduler (#2380)
* Update trainer_builder.py

* Update base.py

* Update __init__.py

* Update base.py

* Update base.py

* Update config.qmd

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* lint

* lint

* lint

* lint

* lint

* lint

* Update base.py

* Update base.py

* lint

* Update base.py

* Update base.py

* Move RexLR to `schedulers.py`

* Remove RexLR from `base.py`

* Fix tooltip formatting

* lint

* Create test_schedulers.py

* Use a default optimizer in test

* lint

* lint

* Add `warmup_steps` and `cosine_min_lr_ratio` to test

* lint
2025-03-05 10:26:11 -05:00
NanoCode012
d4de93a7bb feat(grpo): add reward_weights config and refactor (#2365) 2025-03-05 10:02:08 -05:00
NanoCode012
c8191394e9 fix(doc): add missing low_cpu_mem_usage config to docs (#2369) [skip ci] 2025-03-05 10:01:44 -05:00
NanoCode012
f18231c653 chore(doc): add clarification about mpi4py error on single gpu deepspeed (#2383) [skip ci]
* chore(doc): add clarification about mpi4py error on single gpu deepspeed

* fix: lint
2025-03-05 10:01:28 -05:00
NanoCode012
9ed4f6b3aa feat(doc): document drop_system_message and clarify limitation (#2381) [skip ci] 2025-03-05 10:01:16 -05:00
NanoCode012
05dddfc41d feat(doc): add docker images explanation (#2379) [skip ci]
* feat(doc): add docker images explanation

* chore: add link to dockerhub
2025-03-05 10:01:00 -05:00
NanoCode012
8e30917440 chore(docs): remove phorm (#2378) [skip ci] 2025-03-05 10:00:50 -05:00
NanoCode012
d883b11b6f fix(doc): add installation for cce to docs (#2375) [skip ci]
* fix(doc): add installation for cce to docs

* fix: format
2025-03-05 10:00:39 -05:00
Dan Saunders
f4910dd2ea train.py refactor (#2371)
* refactor train.py

* updates

* update

* combine like functions

* review comments
2025-03-05 08:58:33 -05:00
198 changed files with 2257 additions and 3100 deletions

View File

@@ -40,6 +40,12 @@ jobs:
           python_version: "3.11"
           pytorch: 2.6.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+        - cuda: "128"
+          cuda_version: 12.8.1
+          cudnn_version: ""
+          python_version: "3.11"
+          pytorch: nightly
+          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -61,7 +67,7 @@ jobs:
         uses: docker/build-push-action@v4
         with:
           context: .
-          file: ./docker/Dockerfile-base
+          file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
           push: ${{ github.event_name != 'pull_request' }}
           tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
          labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -88,6 +88,11 @@ jobs:
           pytorch: 2.5.1
           axolotl_extras:
           is_latest: true
+        - cuda: 124
+          cuda_version: 12.4.1
+          python_version: "3.11"
+          pytorch: 2.6.0
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -80,6 +80,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.5.1
           axolotl_extras:
+        - cuda: 124
+          cuda_version: 12.4.1
+          python_version: "3.11"
+          pytorch: 2.6.0
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -0,0 +1,49 @@
name: Pre-commit auto-update

on:
  schedule:
    - cron: '0 0 * * 0' # Run weekly
  workflow_dispatch: # Manual kickoff

jobs:
  auto-update:
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Update pre-commit hooks
        id: update
        run: |
          pip install pre-commit
          pre-commit autoupdate
          if [[ -n $(git status --porcelain) ]]; then
            echo "changes=true" >> $GITHUB_OUTPUT
            git diff .pre-commit-config.yaml > pre-commit-update.diff
          fi

      - name: Create Pull Request
        if: steps.update.outputs.changes == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          branch: update/pre-commit-hooks
          delete-branch: true
          title: "chore: update pre-commit hooks"
          commit-message: "chore: update pre-commit hooks"
          body: |
            Automated PR to update pre-commit hooks to their latest versions.

            <details>
            <summary>Changes:</summary>

            ```diff
            ${{ steps.update.outputs.diff }}
            ```
            </details>
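
The `update` step above can be previewed locally before trusting the weekly job; a minimal sketch mirroring its commands:

```bash
pip install pre-commit
pre-commit autoupdate                # bumps `rev:` pins in .pre-commit-config.yaml
git diff .pre-commit-config.yaml     # review the changes the PR would carry
```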

View File

@@ -40,7 +40,7 @@ jobs:
       - name: Install dependencies
         run: |
-          pip3 install wheel packaging
+          pip3 install wheel packaging==23.2
           pip3 install --no-build-isolation -e .
           pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
       - name: Install PyTorch
         run: |
@@ -59,7 +59,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging
+          pip3 install --upgrade packaging==23.2
           pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
       - name: Install PyTorch
         run: |
@@ -147,7 +147,7 @@ jobs:
       - name: upgrade pip
         run: |
           pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools setuptools_scm build wheel
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
       - name: Install PyTorch
         run: |

View File

@@ -3,7 +3,7 @@ default_language_version:
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
       - id: no-commit-to-branch
        args: ['--branch', 'main']
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 25.1.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 6.0.1
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 6.1.0
+    rev: 7.1.2
     hooks:
       - id: flake8
-  - repo: https://github.com/PyCQA/pylint
-    rev: v3.3.0
+  - repo: https://github.com/pylint-dev/pylint
+    rev: v3.3.6
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.3.0
+    rev: v1.15.0
     hooks:
       - id: mypy
        additional_dependencies:
@@ -36,7 +36,7 @@ repos:
           'pydantic>=2.5.3',
         ]
   - repo: https://github.com/PyCQA/bandit
-    rev: 1.7.5
+    rev: 1.8.3
     hooks:
       - id: bandit
        args: [

View File

@@ -19,9 +19,6 @@
         <br/>
         <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
         <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
-        <a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
-            <img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,…">
-        </a>
     </p>

 Axolotl is a tool designed to streamline post-training for various AI models.
@@ -58,6 +55,7 @@ Features:
 ### Installation

 ```bash
+pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

 # Download example axolotl configs, deepspeed configs

View File

@@ -32,14 +32,16 @@ website:
       contents:
         - docs/getting-started.qmd
         - docs/installation.qmd
-        - docs/cli.qmd
         - docs/inference.qmd
+        - docs/cli.qmd
+        - docs/config.qmd
     - section: "Dataset Formats"
       contents: docs/dataset-formats/*
     - section: "Deployments"
       contents:
+        - docs/docker.qmd
         - docs/multi-gpu.qmd
         - docs/multi-node.qmd
         - docs/ray-integration.qmd
@@ -73,10 +75,6 @@ website:
         - docs/debugging.qmd
         - docs/nccl.qmd
-    - section: "Reference"
-      contents:
-        - docs/config.qmd
 format:
   html:
     theme: darkly

View File

@@ -31,6 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
         sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi

+RUN pip install packaging==23.2 setuptools==75.8.0
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \

View File

@@ -1,6 +1,7 @@
 """
 modal application to run axolotl gpu tests in Modal
 """
+
 # pylint: disable=duplicate-code
 import os

View File

@@ -1,4 +1,5 @@
 """Modal app to run axolotl GPU tests"""
+
 # pylint: disable=duplicate-code
 import os

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace

-RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
+RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
     python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

RUN apt-get update \
    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
    && wget \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
    python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

RUN git lfs install --skip-repo && \
    pip3 install awscli && \
    # The base image ships with `pydantic==1.8.2` which is not working
    pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
 RUN pip install jupyterlab notebook ipywidgets && \
     jupyter lab clean

-RUN apt install --yes --no-install-recommends openssh-server tmux && \
+RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
     mkdir -p ~/.ssh && \
     chmod 700 ~/.ssh && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -1,5 +1,5 @@
 ---
-title: Config options
+title: Config Reference
 description: A complete list of all configuration options.
 ---

@@ -30,6 +30,8 @@ tokenizer_legacy:
 # Resize the model embeddings when new tokens are added to multiples of 32
 # This is reported to improve training speed on some models
 resize_token_embeddings_to_32x:
+# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
+shrink_embeddings:

 # (Internal use only)
 # Used to identify which the model is based on

@@ -83,6 +85,12 @@ gpu_memory_limit: 20GiB
 # Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
 lora_on_cpu: true

+# List[str]. Add plugins to extend the pipeline.
+# See `src/axolotl/integrations` for the available plugins or doc below for more details.
+# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
+plugins:
+# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
 # A list of one or more datasets to finetune the model with
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files

@@ -154,8 +162,6 @@ datasets:
         content: value
       # ...

-    message_property_mappings:
-
     # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
     roles:
       user: ["human", "user"]

@@ -163,6 +169,12 @@ datasets:
       system: ["system"]
       tool: ["tool"]

+    # Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
+    # This does not drop the default system message from chat_template if it exists. If you wish to,
+    # we recommend using a custom jinja template with the default system message removed or
+    # adding a system turn with empty content.
+    drop_system_message:
+
     # IMPORTANT: The following fields determine which parts of the conversation to train on.
     # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
     # See examples at `docs/dataset-formats/conversation.qmd`

@@ -201,10 +213,46 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl

-# use RL training: 'dpo', 'ipo', 'kto'
+# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
 rl:
-# whether to perform weighting if doing DPO training. Boolean.
-dpo_use_weighting:
+rl_beta: # Optional[float]. The beta parameter for the RL training.
+
+# dpo
+dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
+rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
+
+# orpo
+orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
+
+# kto
+kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
+kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
+
+# simpo
+cpo_alpha: 1.0 # Weight of the BC regularizer
+simpo_gamma: 0.5 # Target reward margin for the SimPO loss
+
+# grpo
+trl:
+  use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
+  vllm_device: # Optional[str]. Device to use for VLLM.
+  vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
+  vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
+  vllm_dtype: # Optional[str]. Data type for VLLM.
+  beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
+  max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
+  reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
+  reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
+  num_generations: # Optional[int]. Number of generations to sample.
+  log_completions: # Optional[bool]. Whether to log completions.
+  sync_ref_model: # Optional[bool]. Whether to sync the reference model.
+  ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
+  ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.

 # reward modelling: `True` or `False`
 reward_model:

@@ -222,13 +270,13 @@ process_reward_model:
 chat_template: tokenizer_default
 # custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
 chat_template_jinja: null
-# Changes the default system message
-default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
+# Changes the default system message. Currently only supports chatml.
+default_system_message: You are a helpful assistant. Please give a long and detailed answer.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
 # Push prepared dataset to hub
-push_dataset_to_hub: # repo path
+push_dataset_to_hub: # Optional[str] repo_org/repo_name
 # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
 # if not set.
 dataset_processes: # defaults to os.cpu_count() if not set

@@ -445,7 +493,7 @@ gradient_checkpointing: false
 early_stopping_patience: 3

 # Specify a scheduler and kwargs to use with the optimizer
-lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
+lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
 lr_scheduler_kwargs:
 cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
 cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)

@@ -528,6 +576,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
 sdp_attention:
 # Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
 s2_attention:
+# Optional[bool]. Whether to use low_cpu_mem_usage
+low_cpu_mem_usage:

 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
 # If resume_from_checkpoint isn't set and you simply want it to start where it left off.

@@ -548,6 +598,13 @@ special_tokens:
 # Add extra tokens.
 tokens:

+# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
+# Only works for tokens that are not part of the base vocab (aka are added_tokens).
+# Can be checked if they exist in tokenizer.json added_tokens.
+added_tokens_overrides: # Dict[int, str]
+#   128041: "<|im_start|>"
+#   128042: "<|im_end|>"
+
 # FSDP
 fsdp:
 fsdp_config:
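
Pulling a few of the newly documented options together, a hypothetical config fragment might look like this (values are illustrative, not defaults):

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

shrink_embeddings: true   # trim embeddings back to len(tokenizer)
lr_scheduler: rex         # the newly added REX scheduler
low_cpu_mem_usage: true

added_tokens_overrides:   # remap reserved added_tokens by id
  128041: "<|im_start|>"
  128042: "<|im_end|>"
```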

View File

@@ -55,3 +55,47 @@ sections = [
 for section_name, folder_name in sections:
     print(print_section(section_name, folder_name))
 ```
+
+## Adding a new integration
+
+Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
+
+To add a new integration, please follow these steps:
+
+1. Create a new folder in the `src/axolotl/integrations` directory.
+2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
+3. Add `__init__.py` and `args.py` files to the new folder.
+    - `__init__.py` should import the integration and hook into the appropriate functions.
+    - `args.py` should define the arguments for the integration.
+4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
+
+::: {.callout-tip}
+See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
+:::
+
+::: {.callout-warning}
+If you could not load your integration, please ensure you are pip installing in editable mode
+
+```bash
+pip install -e .
+```
+
+and correctly spelled the integration name in the config file.
+
+```yaml
+plugins:
+  - axolotl.integrations.your_integration_name.YourIntegrationPlugin
+```
+:::
+
+::: {.callout-note}
+It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed as a package in your python env.
+
+See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
+:::
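
To make step 3 above concrete, a rough sketch of the two files follows. All names here are hypothetical, and `get_input_args` is shown as the hook for registering the args model; check `axolotl.integrations.base.BasePlugin` for the exact method surface before relying on it:

```python
# src/axolotl/integrations/my_integration/__init__.py -- hypothetical sketch
from axolotl.integrations.base import BasePlugin


class MyIntegrationPlugin(BasePlugin):
    """Minimal plugin that only registers its config arguments."""

    def get_input_args(self):
        # Dotted path to the pydantic model defined in args.py below
        return "axolotl.integrations.my_integration.args.MyIntegrationArgs"
```

```python
# src/axolotl/integrations/my_integration/args.py -- hypothetical sketch
from typing import Optional

from pydantic import BaseModel


class MyIntegrationArgs(BaseModel):
    """Config keys this integration adds to the axolotl YAML."""

    my_integration_enabled: Optional[bool] = None
```

The plugin would then be enabled by listing `axolotl.integrations.my_integration.MyIntegrationPlugin` under `plugins:`, as in the warning above.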

View File

@@ -74,6 +74,10 @@ datasets:
     train_on_eos:
 ```

+::: {.callout-tip}
+If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
+:::
+
 2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.

 ```yaml

View File

@@ -129,6 +129,7 @@ You can mix and match within each approach or across approaches to train a model

 We suggest this approach when you want to bring your own tokenized dataset.

 Axolotl expects the dataset to have three keys:

 - `input_ids`: from tokenizing formatted prompt
 - `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
 - `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
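
For illustration, a single pre-tokenized example carrying these three keys might look like the row below (token ids invented; no padding, so the mask is all ones, and the prompt positions in `labels` are set to `-100`):

```json
{"input_ids": [1, 887, 526, 263, 2], "attention_mask": [1, 1, 1, 1, 1], "labels": [-100, -100, -100, 263, 2]}
```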

docs/docker.qmd (new file, 140 lines)
View File

@@ -0,0 +1,140 @@
---
title: "Docker"
format:
  html:
    toc: true
    toc-depth: 4
---

This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).

## Base

The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.

#### Image

```
axolotlai/axolotl-base
```

Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)

#### Tags format

```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```

Tags examples:

- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`

## Main

The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.

#### Image

```
axolotlai/axolotl
```

Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)

#### Tags format {#sec-main-tags}

```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}

# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest

# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}

# tagged release
{version}
```

:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::

Tags examples:

- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`

## Cloud

The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.

:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::

#### Image

```
axolotlai/axolotl-cloud
```

Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)

#### Tags format

This uses the same tags as the [`main` image](#sec-main-tags).

#### Environment variables

- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.

#### Volume mounts

:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::

- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.

## Cloud-no-tmux

This is the same as the [`cloud` image](#sec-cloud) but without tmux.

#### Image

```
axolotlai/axolotl-cloud-term
```

Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)

:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::

#### Tags format

This uses the same tags as the [`cloud` image](#sec-cloud-tags).
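
Putting the cloud image's environment variables and volume-mount advice together, a launch command might look like the sketch below (host paths and the Jupyter choice are illustrative, not required):

```bash
# Disable Jupyter and persist artifacts + HF cache across container restarts
docker run --gpus '"all"' --shm-size 10g --rm -it \
  -e JUPYTER_DISABLE=1 \
  -v /mnt/disk/axolotl-artifacts:/workspace/data/axolotl-artifacts \
  -v /mnt/disk/hf-cache:/workspace/data/huggingface-cache \
  axolotlai/axolotl-cloud:main-latest
```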

View File

@@ -19,12 +19,24 @@ description: Frequently asked questions

 **Q: AttributeError: 'DummyOptim' object has no attribute 'step'**

-> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
+> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
+
+**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
+
+> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.

 **Q: The codes is stuck on saving preprocessed datasets.**

 > A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.

+**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
+
+> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
+>
+> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
+
+**Q: How to call Axolotl via custom python scripts?**
+
+> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
+
 ### Chat templates

 **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

@@ -50,3 +62,7 @@ description: Frequently asked questions
 **Q: The EOS/EOT token is incorrectly being masked or not being masked.**

 > A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
+
+**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
+
+> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
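
For the embedding-mismatch FAQ above, the recommended flow might look like this sketch (the config filename is hypothetical; `axolotl merge-lora` is the command named in the answer):

```bash
# Reuse the same config that produced the adapter; add `shrink_embeddings: true`
# to it first if the model has more tokens than the tokenizer.
axolotl merge-lora your_config.yml
```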

View File

@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
 ```yaml
 base_model: NousResearch/Llama-3.2-1B
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: true
+adapter: lora

 datasets:
   - path: teknium/GPT4-LLM-Cleaned
@@ -44,11 +46,15 @@ datasets:
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.1
 output_dir: ./outputs/lora-out
-
-adapter: lora
-lora_model_dir:
 ```

+::: {.callout-tip}
+`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
+
+- To perform Full finetuning, remove these two lines.
+- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
+:::
+
 See our [Config options](config.qmd) for more details.

 ### Training {#sec-training}

@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
 When you run `axolotl train`, Axolotl:

 1. Downloads the base model
-2. (If specified) applies LoRA adapter layers
+2. (If specified) applies QLoRA/LoRA adapter layers
 3. Loads and processes the dataset
 4. Runs the training loop
 5. Saves the trained model and / or LoRA weights

@@ -69,6 +75,8 @@ Let's modify the example for your own data:
 ```yaml
 base_model: NousResearch/Nous-Hermes-llama-1b-v1
+load_in_8bit: true
 adapter: lora

 # Training settings

@@ -104,8 +112,6 @@ format):
 {"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
 ```

-Please consult the supported [Dataset Formats](dataset-formats/) for more details.
-
 3. Run the training:

 ```bash

View File

@@ -1,5 +1,5 @@
 ---
-title: "Inference"
+title: "Inference and Merging"
 format:
   html:
     toc: true
@@ -9,10 +9,14 @@ execute:
   enabled: false
 ---

-This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
+This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.

 ## Quick Start {#sec-quickstart}

+::: {.callout-tip}
+Use the same config used for training on inference/merging.
+:::
+
 ### Basic Inference {#sec-basic}

 ::: {.panel-tabset}

View File

@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your environments.
 ### PyPI Installation (Recommended) {#sec-pypi}

 ```{.bash}
+pip3 install -U packaging setuptools wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
 ```

@@ -37,7 +38,7 @@ For the latest features between releases:
 ```{.bash}
 git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl
-pip3 install packaging ninja
+pip3 install -U packaging setuptools wheel ninja
 pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```

@@ -65,6 +66,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
 ```
 :::

+Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
+
 ## Cloud Environments {#sec-cloud}

 ### Cloud GPU Providers {#sec-cloud-gpu}

@@ -76,6 +79,7 @@ For providers supporting Docker:
 - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
 - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
 - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
+- [Novita](https://novita.ai/gpus-console?templateId=311)

 ### Google Colab {#sec-colab}

@@ -105,7 +109,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
 2. Install PyTorch: https://pytorch.org/get-started/locally/
 3. Install Axolotl:

 ```{.bash}
-pip3 install packaging
+pip3 install -U packaging setuptools wheel ninja
 pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
 ```
 4. (Optional) Login to Hugging Face:

View File

@@ -66,6 +66,10 @@ logic to be compatible with more of them.
 </details>

+::: {.callout-tip}
+Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
+:::
+
 ## Usage

 These optimizations can be enabled in your Axolotl config YAML file. The

View File

@@ -28,8 +28,23 @@ val_set_size: 0.1
 eval_steps: 100
 ```

+Bradley-Terry chat templates expect single-turn conversations in the following format:
+
+```json
+{
+    "system": "...", // optional
+    "input": "...",
+    "chosen": "...",
+    "rejected": "..."
+}
+```
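
A concrete single-turn row in this format might look like the following (contents invented for illustration):

```json
{
    "system": "You are a concise assistant.",
    "input": "What is the capital of France?",
    "chosen": "Paris.",
    "rejected": "I believe it is Lyon."
}
```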
 ### Process Reward Models (PRM)

+::: {.callout-tip}
+Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
+:::
+
 Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.

 ```yaml
 base_model: Qwen/Qwen2.5-3B

@@ -45,3 +60,5 @@ datasets:
 val_set_size: 0.1
 eval_steps: 100
 ```
+
+Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
 description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
 back-to-top-navigation: true
 toc: true
+toc-expand: 2
 toc-depth: 4
 ---

@@ -297,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
 ### IPO

-As IPO is just DPO with a different loss function, all supported options for DPO works here.
+As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

 ```yaml
 rl: ipo
 ```

@@ -343,8 +344,9 @@ ORPO supports the following types with the following dataset format:
 ```yaml
 rl: kto
-rl_beta: 0.5
-kto_desirable_weight: 0.2
+rl_beta: 0.1 # default
+kto_desirable_weight: 1.0 # default
+kto_undesirable_weight: 1.0 # default

 remove_unused_columns: false
 ```

@@ -496,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
 ### GRPO

+::: {.callout-tip}
+Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
+:::
+
 GRPO uses custom reward functions and transformations. Please have them ready locally.

 For ex, to load OpenAI's GSM8K and use a random reward for completions:

@@ -528,6 +534,7 @@ trl:
   vllm_gpu_memory_utilization: 0.15
   num_generations: 4
   reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
+  reward_weights: [1.0]
 datasets:
   - path: openai/gsm8k
     name: main

@@ -536,6 +543,21 @@ datasets:
 To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
+
+To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
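
For reference, a local `rewards.py` matching the `'{file_name}.{fn_name}'` convention above could be as small as the sketch below (a random reward, as in the GSM8K example; see the linked TRL docs for the full signature contract):

```python
# rewards.py -- minimal custom reward function for GRPO
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    """Return one (random) score per sampled completion."""
    return [random.random() for _ in completions]
```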
+### SimPO
+
+SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
+
+```yaml
+rl: simpo
+rl_beta: 0.1 # default in CPOTrainer
+cpo_alpha: 1.0 # default in CPOTrainer
+simpo_gamma: 0.5 # default in CPOTrainer
+```
+
+This method uses the same dataset format as [DPO](#dpo).

 ### Using local dataset files

 ```yaml

View File

@@ -1,59 +0,0 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---

# Telemetry in Axolotl

Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.

## Data Collection

We collect:

- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
  version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
  information)

No personally identifiable information (PII) is collected.

## Implementation

Telemetry is implemented using PostHog and consists of:

- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
  telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
  sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
  runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
  runtime metrics telemetry.

The telemetry system will block training startup for 15 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.

## Opt-Out Mechanism

Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:

- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)

To acknowledge and explicitly enable telemetry (and remove the warning message), set:
`AXOLOTL_DO_NOT_TRACK=0`.

## Privacy

- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
  - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
  - This allows us to link different telemetry events in a single training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -55,7 +55,7 @@ tf32: true
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
-  use_reentrant: true
+  use_reentrant: false

 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:

View File

@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
+requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
 build-backend = "setuptools.build_meta"

 [project]

@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
 description = "LLM Trainer"
 readme = "README.md"
 requires-python = ">=3.10"
+# license = "Apache-2.0"

 [project.scripts]
 axolotl = "axolotl.cli.main:main"

View File

@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.45.2
+bitsandbytes==0.45.3
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 flash-attn==2.7.4.post1

@@ -12,12 +12,12 @@ liger-kernel==0.5.3
 packaging==23.2
-peft==0.14.0
+peft==0.15.0
 transformers==4.49.0
-tokenizers>=0.21.0
+tokenizers>=0.21.1
-accelerate==1.3.0
+accelerate==1.5.2
-datasets==3.2.0
+datasets==3.4.1
-deepspeed==0.16.1
+deepspeed==0.16.4
 trl==0.15.1
 optimum==1.16.2

@@ -62,7 +62,5 @@ antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
-axolotl-contribs-lgpl==0.0.3
+axolotl-contribs-lgpl==0.0.6
+axolotl-contribs-mit==0.0.3
-
-# telemetry
-posthog>=3.15.1

View File

@@ -1,6 +1,7 @@
""" """
helper script to parse chat datasets into a usable yaml helper script to parse chat datasets into a usable yaml
""" """
import click import click
import yaml import yaml
from datasets import load_dataset from datasets import load_dataset

View File

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy.""" """Script to output the correct installation command for cut-cross-entropy."""
import importlib.util import importlib.util
import sys import sys
@@ -24,5 +25,5 @@ if cce_spec:
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"' + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
) )

View File

@@ -128,7 +128,7 @@ setup(
"flash-attn==2.7.4.post1", "flash-attn==2.7.4.post1",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.16.1", "deepspeed==0.16.4",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -1,6 +1,7 @@
""" """
launch axolotl in supported cloud platforms launch axolotl in supported cloud platforms
""" """
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union

View File

@@ -1,6 +1,7 @@
""" """
base class for cloud platforms from cli base class for cloud platforms from cli
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@@ -1,6 +1,7 @@
""" """
Modal Cloud support from CLI Modal Cloud support from CLI
""" """
import copy import copy
import json import json
import os import os
@@ -113,7 +114,7 @@ class ModalCloud(Cloud):
[ [
# Random id for cache busting of branch commits # Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311 f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}", f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
] ]
) )
@@ -270,6 +271,7 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml) f_out.write(config_yaml)
run_folder = "/workspace/mounts" run_folder = "/workspace/mounts"
@@ -288,6 +290,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None): def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml) f_out.write(config_yaml)
run_folder = "/workspace/mounts" run_folder = "/workspace/mounts"

View File

@@ -14,8 +14,6 @@ import yaml
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.integrations.base import PluginManager
-from axolotl.telemetry.errors import send_errors
-from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils.comet_ import setup_comet_env_vars
 from axolotl.utils.config import (
     normalize_cfg_datasets,

@@ -29,8 +27,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
 LOG = logging.getLogger(__name__)

-TELEMETRY_MANAGER = TelemetryManager.get_instance()
-

 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
     """

@@ -156,7 +152,6 @@ def prepare_plugins(cfg: DictDefault):
         plugin_manager.register(plugin_name)

-@send_errors
 def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs

@@ -176,7 +171,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
     # Load the config from the yaml file
     with open(config, encoding="utf-8") as file:
         cfg: DictDefault = DictDefault(yaml.safe_load(file))
-    TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)

     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value

@@ -220,6 +214,4 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
     setup_mlflow_env_vars(cfg)
     setup_comet_env_vars(cfg)

-    TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
-
     return cfg

View File

@@ -17,7 +17,6 @@ from axolotl.cli.args import InferenceCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
-from axolotl.telemetry.errors import send_errors
 from axolotl.utils.chat_templates import (
     get_chat_template,
     get_chat_template_from_config,

@@ -43,7 +42,6 @@ def get_multi_line_input() -> str:
     return instruction

-@send_errors
 def do_inference(
     *,
     cfg: DictDefault,

@@ -137,7 +135,6 @@ def do_inference(
     print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))

-@send_errors
 def do_inference_gradio(
     *,
     cfg: DictDefault,

View File

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands.""" """Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import logging import logging

View File

@@ -12,13 +12,11 @@ from axolotl.cli.args import TrainerCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
-from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault

 LOG = logging.getLogger(__name__)

-@send_errors
 def do_merge_lora(*, cfg: DictDefault) -> None:
     """
     Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -27,7 +27,6 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
-from axolotl.telemetry.errors import send_errors

 LOG = logging.getLogger(__name__)

@@ -121,7 +120,6 @@ def _distributed_checkpoint_to_merged_weights(
     return save_path_

-@send_errors
 def merge_fsdp_weights(
     checkpoint_dir: str,
     output_path: str,

View File

@@ -18,14 +18,12 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.common.datasets import load_datasets, load_preference_datasets
-from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.trainer import disable_datasets_caching

 LOG = logging.getLogger(__name__)

-@send_errors
 def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     """
     Preprocesses dataset specified in axolotl config.

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model.""" """CLI to run training on a model."""
import logging import logging
import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -34,18 +35,20 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl: if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta) model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
del model del model
del tokenizer del tokenizer
del trainer
plugin_manager.post_train_unload(cfg) plugin_manager.post_train_unload(cfg)
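The new guard uses the `LOCAL_RANK` environment variable that `torchrun` and `accelerate` export per worker, so the Hugging Face token is validated once per node rather than once per GPU process. The pattern in isolation (the guarded call is a stand-in):

```python
import os


def is_local_main_process() -> bool:
    # torchrun/accelerate export LOCAL_RANK for each worker; when it is absent
    # we are running a single process, which counts as rank 0.
    return int(os.getenv("LOCAL_RANK", "0")) == 0


if is_local_main_process():
    print("validating HF token only once per node")  # stand-in for check_user_token()
```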

View File

@@ -5,7 +5,6 @@ import dataclasses
 import hashlib
 import json
 import logging
-import typing
 from functools import wraps
 from pathlib import Path
 from types import NoneType

@@ -24,7 +23,7 @@ configure_logging()
 LOG = logging.getLogger(__name__)

-def strip_optional_type(field_type: type | typing._SpecialForm | None):
+def strip_optional_type(field_type: type | str | None):
     """
     Extracts the non-`None` type from an `Optional` / `Union` type.
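For readers unfamiliar with the helper being retyped here: it unwraps `Optional[T]` annotations. A sketch of the idea using stdlib `typing` introspection (not the project's exact implementation, which per the new signature also tolerates string annotations):

```python
import typing
from types import NoneType


def strip_optional(field_type):
    """Return T for Optional[T]; leave other annotations unchanged."""
    if typing.get_origin(field_type) is typing.Union:
        non_none = [a for a in typing.get_args(field_type) if a is not NoneType]
        if len(non_none) == 1:
            return non_none[0]
    return field_type


assert strip_optional(typing.Optional[int]) is int
assert strip_optional(str) is str
```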

View File

@@ -10,7 +10,6 @@ from datasets import Dataset
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
-from axolotl.telemetry.errors import send_errors
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault

@@ -25,8 +24,8 @@ class TrainDatasetMeta:
     """Dataclass with fields for training and validation datasets and metadata."""

     train_dataset: Dataset
-    eval_dataset: Optional[Dataset] = None
-    total_num_steps: Optional[int] = None
+    eval_dataset: Dataset | None = None
+    total_num_steps: int | None = None

 def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:

@@ -45,7 +44,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
     )

-@send_errors
 def load_datasets(
     *,
     cfg: DictDefault,

@@ -105,7 +103,6 @@ def load_datasets(
     )

-@send_errors
 def load_preference_datasets(
     *,
     cfg: DictDefault,

View File

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json import json
import sys import sys

View File

@@ -1,6 +1,7 @@
""" """
ChatML transformation functions for MessageContents ChatML transformation functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
Llama 3.x chat formatting functions for MessageContents Llama 3.x chat formatting functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
shared functions for format transforms shared functions for format transforms
""" """
from axolotl.core.chat.messages import MessageContents, Messages from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
""" """
internal message representations of chat messages internal message representations of chat messages
""" """
import json import json
from enum import Enum from enum import Enum
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union

View File

@@ -1,6 +1,7 @@
""" """
chat dataset module chat dataset module
""" """
import os import os
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
@@ -43,7 +44,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = ( process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment] process_count or os.cpu_count() # type: ignore[assignment]
) )
num_proc = min(64, process_or_cpu_count) num_proc = min(32, process_or_cpu_count)
features = data.features.keys() features = data.features.keys()
tokenized_data = data.map( tokenized_data = data.map(
map_fn, map_fn,
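The lowered cap matters because `datasets.Dataset.map` forks one worker per `num_proc`; past a point, process startup and IPC overhead can outweigh the parallel tokenization gains on many-core hosts. The capping pattern in isolation:

```python
import os

from datasets import Dataset

data = Dataset.from_dict({"text": ["a", "bb", "ccc", "dddd"]})

# Cap workers at 32 (and never above the row count) rather than one per core.
num_proc = min(32, os.cpu_count() or 1, len(data))
tokenized = data.map(lambda row: {"n_chars": len(row["text"])}, num_proc=num_proc)
print(tokenized["n_chars"])
```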

View File

@@ -1,6 +1,7 @@
""" """
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
""" """
from typing import Any, Mapping, Union from typing import Any, Mapping, Union

View File

@@ -35,6 +35,7 @@ from transformers import (
     EarlyStoppingCallback,
     TrainerCallback,
 )
+from transformers.training_args import OptimizerNames
 from trl.trainer.utils import RewardDataCollatorWithPadding

 from axolotl.core.trainers.base import (
from axolotl.core.trainers.base import ( from axolotl.core.trainers.base import (
@@ -61,8 +62,6 @@ from axolotl.core.training_args import (
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
-from axolotl.telemetry.callbacks import TelemetryCallback
-from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,

@@ -86,6 +85,7 @@ from axolotl.utils.collators import (
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
+from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
 from axolotl.utils.models import ensure_dtype

 try:
try: try:
@@ -93,13 +93,11 @@ try:
 except ImportError:
     pass

-LOG = logging.getLogger("axolotl.core.trainer_builder")
+LOG = logging.getLogger(__name__)

 class TrainerBuilderBase(abc.ABC):
-    """
-    Base class for trainer builder
-    """
+    """Base class for trainer builder."""

     _train_dataset = None
     _eval_dataset = None

@@ -112,9 +110,9 @@ class TrainerBuilderBase(abc.ABC):
         self.tokenizer = tokenizer
         self.processor = processor

-        # in case the model supports tagging, add the axolotl tag.
+        # If the model supports tagging, add the axolotl tag.
         # This makes sure the tag is correctly pushed even if a user calls
-        # model.push_to_hub instad of trainer.push_to_hub.
+        # model.push_to_hub instead of trainer.push_to_hub.
         if hasattr(model, "add_model_tags"):
             model.add_model_tags(["axolotl"])

@@ -178,8 +176,10 @@ class TrainerBuilderBase(abc.ABC):
             SaveAxolotlConfigtoMlflowCallback,
         )

-        callbacks.append(
-            SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
+        callbacks.extend(
+            [
+                SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
+            ]
         )

         if self.cfg.use_comet and is_comet_available():
             from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback

@@ -188,10 +188,6 @@ class TrainerBuilderBase(abc.ABC):
             SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
         )

-        telemetry_manager = TelemetryManager.get_instance()
-        if telemetry_manager.enabled:
-            callbacks.append(TelemetryCallback())
-
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -231,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
 class HFCausalTrainerBuilder(TrainerBuilderBase):
     """
-    Build the HuggingFace training args/trainer for causal models
-    and reward modelling using TRL.
+    Build the HuggingFace training args/trainer for causal models and reward modeling
+    using TRL.
     """

     def get_callbacks(self):
@@ -336,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = {}

         if self.cfg.include_tokens_per_second is not None:
-            training_arguments_kwargs[
-                "include_tokens_per_second"
-            ] = self.cfg.include_tokens_per_second
+            training_arguments_kwargs["include_tokens_per_second"] = (
+                self.cfg.include_tokens_per_second
+            )

         if self.cfg.bf16 == "full":
             training_arguments_kwargs["bf16_full_eval"] = True
@@ -355,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["seed"] = self.cfg.seed training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
training_arguments_kwargs[ training_arguments_kwargs["gradient_checkpointing"] = (
"gradient_checkpointing" self.cfg.gradient_checkpointing
] = self.cfg.gradient_checkpointing )
if self.cfg.gradient_checkpointing_kwargs is not None: if self.cfg.gradient_checkpointing_kwargs is not None:
training_arguments_kwargs[ training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
"gradient_checkpointing_kwargs" self.cfg.gradient_checkpointing_kwargs
] = self.cfg.gradient_checkpointing_kwargs )
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
@@ -377,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None: if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[ training_arguments_kwargs["lr_quadratic_warmup"] = (
"lr_quadratic_warmup" self.cfg.lr_quadratic_warmup
] = self.cfg.lr_quadratic_warmup )
if self.cfg.adam_beta1: if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -403,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_pin_memory"] = (
"dataloader_pin_memory" self.cfg.dataloader_pin_memory
] = self.cfg.dataloader_pin_memory )
if self.cfg.dataloader_num_workers is not None: if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_num_workers"] = (
"dataloader_num_workers" self.cfg.dataloader_num_workers
] = self.cfg.dataloader_num_workers )
if self.cfg.dataloader_prefetch_factor is not None: if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_prefetch_factor"] = (
"dataloader_prefetch_factor" self.cfg.dataloader_prefetch_factor
] = self.cfg.dataloader_prefetch_factor )
if self.cfg.dataloader_drop_last is not None: if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs[ training_arguments_kwargs["dataloader_drop_last"] = (
"dataloader_drop_last" self.cfg.dataloader_drop_last
] = self.cfg.dataloader_drop_last )
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None: if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[ training_arguments_kwargs["remove_unused_columns"] = (
"remove_unused_columns" self.cfg.remove_unused_columns
] = self.cfg.remove_unused_columns )
if not self.cfg.test_datasets and self.cfg.val_set_size == 0: if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval # no eval set, so don't eval
@@ -456,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.do_causal_lm_eval:
             training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
         if self.cfg.metric_for_best_model:
-            training_arguments_kwargs[
-                "metric_for_best_model"
-            ] = self.cfg.metric_for_best_model
+            training_arguments_kwargs["metric_for_best_model"] = (
+                self.cfg.metric_for_best_model
+            )
         if self.cfg.greater_is_better:
             training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -471,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
             training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
             if self.cfg.torch_compile_backend:
-                training_arguments_kwargs[
-                    "torch_compile_backend"
-                ] = self.cfg.torch_compile_backend
+                training_arguments_kwargs["torch_compile_backend"] = (
+                    self.cfg.torch_compile_backend
+                )
             if self.cfg.torch_compile_mode:
-                training_arguments_kwargs[
-                    "torch_compile_mode"
-                ] = self.cfg.torch_compile_mode
+                training_arguments_kwargs["torch_compile_mode"] = (
+                    self.cfg.torch_compile_mode
+                )

         # DDP Config
         if self.cfg.ddp_timeout:
@@ -486,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.ddp_bucket_cap_mb:
             training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
         if self.cfg.ddp_broadcast_buffers is not None:
-            training_arguments_kwargs[
-                "ddp_broadcast_buffers"
-            ] = self.cfg.ddp_broadcast_buffers
+            training_arguments_kwargs["ddp_broadcast_buffers"] = (
+                self.cfg.ddp_broadcast_buffers
+            )

         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
             total_num_steps if self.cfg.max_steps else -1
         )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
-        training_arguments_kwargs[
-            "per_device_train_batch_size"
-        ] = self.cfg.micro_batch_size
+        training_arguments_kwargs["per_device_train_batch_size"] = (
+            self.cfg.micro_batch_size
+        )
         if self.cfg.eval_batch_size:
-            training_arguments_kwargs[
-                "per_device_eval_batch_size"
-            ] = self.cfg.eval_batch_size
+            training_arguments_kwargs["per_device_eval_batch_size"] = (
+                self.cfg.eval_batch_size
+            )
         if self.cfg.auto_find_batch_size is not None:
-            training_arguments_kwargs[
-                "auto_find_batch_size"
-            ] = self.cfg.auto_find_batch_size
+            training_arguments_kwargs["auto_find_batch_size"] = (
+                self.cfg.auto_find_batch_size
+            )
-        training_arguments_kwargs[
-            "gradient_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs["gradient_accumulation_steps"] = (
+            self.cfg.gradient_accumulation_steps
+        )
-        training_arguments_kwargs[
-            "eval_accumulation_steps"
-        ] = self.cfg.gradient_accumulation_steps
+        training_arguments_kwargs["eval_accumulation_steps"] = (
+            self.cfg.gradient_accumulation_steps
+        )
         training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
         training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -555,34 +551,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else: else:
training_arguments_kwargs["run_name"] = None training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[ training_arguments_kwargs["alternate_lr_scheduler_type"] = (
"alternate_lr_scheduler_type" self.cfg.lr_scheduler
] = self.cfg.lr_scheduler )
else: else:
training_arguments_kwargs["lr_scheduler_type"] = ( training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -591,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
         training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
-        training_arguments_kwargs[
-            "cosine_constant_lr_ratio"
-        ] = self.cfg.cosine_constant_lr_ratio
+        training_arguments_kwargs["cosine_constant_lr_ratio"] = (
+            self.cfg.cosine_constant_lr_ratio
+        )
         training_arguments_kwargs["weight_decay"] = (
             self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
         )
@@ -606,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.eval_sample_packing
         )

         if self.cfg.sample_packing_bin_size is not None:
-            training_arguments_kwargs[
-                "sample_packing_bin_size"
-            ] = self.cfg.sample_packing_bin_size
+            training_arguments_kwargs["sample_packing_bin_size"] = (
+                self.cfg.sample_packing_bin_size
+            )
         if self.cfg.sample_packing_group_size is not None:
-            training_arguments_kwargs[
-                "sample_packing_group_size"
-            ] = self.cfg.sample_packing_group_size
+            training_arguments_kwargs["sample_packing_group_size"] = (
+                self.cfg.sample_packing_group_size
+            )
         if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
+            training_arguments_kwargs["sample_packing_efficiency"] = (
+                self.cfg.sample_packing_eff_est
+            )

         if self.cfg.relora_steps:
             training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs[
-                "relora_warmup_steps"
-            ] = self.cfg.relora_warmup_steps
+            training_arguments_kwargs["relora_warmup_steps"] = (
+                self.cfg.relora_warmup_steps
+            )
         if self.cfg.relora_anneal_steps:
-            training_arguments_kwargs[
-                "relora_anneal_steps"
-            ] = self.cfg.relora_anneal_steps
+            training_arguments_kwargs["relora_anneal_steps"] = (
+                self.cfg.relora_anneal_steps
+            )
         if self.cfg.relora_prune_ratio:
-            training_arguments_kwargs[
-                "relora_prune_ratio"
-            ] = self.cfg.relora_prune_ratio
+            training_arguments_kwargs["relora_prune_ratio"] = (
+                self.cfg.relora_prune_ratio
+            )

         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
-            training_arguments_kwargs[
-                "lisa_step_interval"
-            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs["lisa_step_interval"] = (
+                self.cfg.lisa_step_interval
+            )
-            training_arguments_kwargs[
-                "lisa_layers_attribute"
-            ] = self.cfg.lisa_layers_attribute
+            training_arguments_kwargs["lisa_layers_attribute"] = (
+                self.cfg.lisa_layers_attribute
+            )

         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
@@ -653,59 +627,127 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )

         if self.cfg.neftune_noise_alpha is not None:
-            training_arguments_kwargs[
-                "neftune_noise_alpha"
-            ] = self.cfg.neftune_noise_alpha
+            training_arguments_kwargs["neftune_noise_alpha"] = (
+                self.cfg.neftune_noise_alpha
+            )

         trainer_kwargs = {}

         if self.cfg.reward_model:
             training_arguments_kwargs["max_length"] = self.cfg.sequence_len

-        # pylint: disable=duplicate-code
-        if self.cfg.optimizer in [
-            "optimi_adamw",
-            "ao_adamw_4bit",
-            "ao_adamw_8bit",
-            "ao_adamw_fp8",
-            "adopt_adamw",
-        ]:
-            # Set default so transformers doesn't throw
-            training_arguments_kwargs["optim"] = "adamw_hf"
-            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
-        if self.cfg.optimizer == "lion_pytorch":
-            from lion_pytorch import Lion
-
-            lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
-            if "weight_decay" in training_arguments_kwargs:
-                lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
-            if (
-                "adam_beta1" in training_arguments_kwargs
-                and "adam_beta2" in training_arguments_kwargs
-            ):
-                lion_kwargs["betas"] = (
-                    training_arguments_kwargs["adam_beta1"],
-                    training_arguments_kwargs["adam_beta2"],
-                )
-            trainer_kwargs["optimizers"] = (
-                Lion(params=self.model.parameters(), **lion_kwargs),
-                None,
-            )
-            # Set default so transformers doesn't throw
-            training_arguments_kwargs["optim"] = "adamw_hf"
+        # Handle custom optimizer
+        custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
+        if self.cfg.optimizer in custom_supported_optimizers:
+            # Common optimizer kwargs
+            optimizer_kwargs = {
+                "lr": training_arguments_kwargs.get("learning_rate"),
+                "weight_decay": training_arguments_kwargs.get("weight_decay"),
+            }
+
+            # Adam-specific kwargs
+            adam_kwargs = {}
+            if training_arguments_kwargs.get(
+                "adam_beta1"
+            ) and training_arguments_kwargs.get("adam_beta2"):
+                adam_kwargs["betas"] = (
+                    training_arguments_kwargs.get("adam_beta1"),
+                    training_arguments_kwargs.get("adam_beta2"),
+                )
+            if training_arguments_kwargs.get("adam_epsilon"):
+                adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
+
+            if self.cfg.optimizer == "muon":
+                from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
+                    MuonOptimizerFactory,
+                )
+
+                optimizer_cls = MuonOptimizerFactory
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "optimi_adamw":
+                from optimi import AdamW
+
+                optimizer_kwargs["foreach"] = False
+                optimizer_cls = AdamW
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "ao_adamw_4bit":
+                # TODO remove 20250401
+                from torchao.prototype.low_bit_optim import AdamW4bit
+
+                optimizer_cls = AdamW4bit
+                optimizer_kwargs.update(adam_kwargs)
+                LOG.warning(
+                    f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
+                )
+            elif self.cfg.optimizer == "ao_adamw_8bit":
+                from torchao.prototype.low_bit_optim import AdamW8bit
+
+                optimizer_cls = AdamW8bit
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "ao_adamw_fp8":
+                from torchao.prototype.low_bit_optim import AdamWFp8
+
+                optimizer_cls = AdamWFp8
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "adopt_adamw":
+                from axolotl.utils.optimizers.adopt import ADOPT
+
+                optimizer_cls = ADOPT
+                adam_kwargs["decouple"] = True
+                optimizer_kwargs.update(adam_kwargs)
+
+            # Parse any additional optimizer args from config
+            if self.cfg.optim_args:
+                if isinstance(self.cfg.optim_args, dict):
+                    optimizer_kwargs.update(self.cfg.optim_args)
+                else:
+                    # Parse string format "key1=value1,key2=value2"
+                    for mapping in self.cfg.optim_args.replace(" ", "").split(","):
+                        key, value = mapping.split("=")
+                        optimizer_kwargs[key] = value
+
+            trainer_kwargs["optimizer_cls_and_kwargs"] = (
+                optimizer_cls,
+                optimizer_kwargs,
+            )
+        else:
+            # Use transformers' optimizer
+            training_arguments_kwargs["optim"] = self.cfg.optimizer
+
+            # Parse any additional optimizer args from config
+            if self.cfg.optim_args:
+                if isinstance(self.cfg.optim_args, dict):
+                    optim_args = ",".join(
+                        [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
+                    )
+                else:
+                    optim_args = self.cfg.optim_args
+                training_arguments_kwargs["optim_args"] = optim_args

         if self.cfg.optimizer == "adamw_anyprecision":
             if Path(self.cfg.torchdistx_path).exists():
                 sys.path.append(self.cfg.torchdistx_path)
                 importlib.import_module("torchdistx")

+        if self.cfg.optim_target_modules:
+            training_arguments_kwargs["optim_target_modules"] = (
+                self.cfg.optim_target_modules
+            )
+
+        training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
+        training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_arguments_kwargs["loraplus_lr_embedding"] = (
+            self.cfg.loraplus_lr_embedding
+        )
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
+
         if self.cfg.accelerator_config:
-            training_arguments_kwargs[
-                "accelerator_config"
-            ] = self.cfg.accelerator_config
+            training_arguments_kwargs["accelerator_config"] = (
+                self.cfg.accelerator_config
+            )

         if self.cfg.kd_ce_alpha is not None:
             training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -714,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.kd_temperature is not None:
             training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
         if self.cfg.kd_zscore_base_temp is not None:
-            training_arguments_kwargs[
-                "kd_zscore_base_temp"
-            ] = self.cfg.kd_zscore_base_temp
+            training_arguments_kwargs["kd_zscore_base_temp"] = (
+                self.cfg.kd_zscore_base_temp
+            )
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs[
-                "kd_top_k_before_softmax"
-            ] = self.cfg.kd_top_k_before_softmax
+            training_arguments_kwargs["kd_top_k_before_softmax"] = (
+                self.cfg.kd_top_k_before_softmax
+            )

         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -876,9 +918,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 class HFRLTrainerBuilder(TrainerBuilderBase):
-    """
-    Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
-    """
+    """Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""

     def get_callbacks(self):
         callbacks = super().get_callbacks()
@@ -932,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )

         if self.cfg.remove_unused_columns is not None:
-            training_args_kwargs[
-                "remove_unused_columns"
-            ] = self.cfg.remove_unused_columns
+            training_args_kwargs["remove_unused_columns"] = (
+                self.cfg.remove_unused_columns
+            )
         else:
             training_args_kwargs["remove_unused_columns"] = False

         if self.cfg.dataloader_pin_memory is not None:
-            training_args_kwargs[
-                "dataloader_pin_memory"
-            ] = self.cfg.dataloader_pin_memory
+            training_args_kwargs["dataloader_pin_memory"] = (
+                self.cfg.dataloader_pin_memory
+            )
         if self.cfg.dataloader_num_workers is not None:
-            training_args_kwargs[
-                "dataloader_num_workers"
-            ] = self.cfg.dataloader_num_workers
+            training_args_kwargs["dataloader_num_workers"] = (
+                self.cfg.dataloader_num_workers
+            )
         if self.cfg.dataloader_prefetch_factor is not None:
-            training_args_kwargs[
-                "dataloader_prefetch_factor"
-            ] = self.cfg.dataloader_prefetch_factor
+            training_args_kwargs["dataloader_prefetch_factor"] = (
+                self.cfg.dataloader_prefetch_factor
+            )
         if self.cfg.gradient_checkpointing:
-            training_args_kwargs[
-                "gradient_checkpointing"
-            ] = self.cfg.gradient_checkpointing
+            training_args_kwargs["gradient_checkpointing"] = (
+                self.cfg.gradient_checkpointing
+            )
             if self.cfg.gradient_checkpointing_kwargs is not None:
-                training_args_kwargs[
-                    "gradient_checkpointing_kwargs"
-                ] = self.cfg.gradient_checkpointing_kwargs
+                training_args_kwargs["gradient_checkpointing_kwargs"] = (
+                    self.cfg.gradient_checkpointing_kwargs
+                )
             else:
                 training_args_kwargs["gradient_checkpointing_kwargs"] = {
                     "use_reentrant": False
@@ -1031,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
         if self.cfg.dpo_use_logits_to_keep is not None:
-            training_args_kwargs[
-                "use_logits_to_keep"
-            ] = self.cfg.dpo_use_logits_to_keep
+            training_args_kwargs["use_logits_to_keep"] = (
+                self.cfg.dpo_use_logits_to_keep
+            )

         for blocklist_key in blocklist_args_kwargs:
             if blocklist_key in training_args_kwargs:
@@ -1068,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.adapter and self.peft_config:
             dpo_trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
-            dpo_trainer_kwargs[
-                "precompute_ref_log_probs"
-            ] = self.cfg.precompute_ref_log_probs
+            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
+                self.cfg.precompute_ref_log_probs
+            )
         if self.cfg.rl == "grpo":
             trainer_cls = GRPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model]
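For orientation: the `optimizer_cls_and_kwargs` tuple the builder now populates is an argument that `transformers.Trainer` itself accepts in recent releases (it appeared around v4.47), which is why instantiation can be deferred to the trainer. A hedged sketch of that hand-off, with a toy model and stock `AdamW` as placeholders:

```python
import torch
from transformers import Trainer, TrainingArguments

model = torch.nn.Linear(4, 2)  # toy stand-in for a causal LM
args = TrainingArguments(output_dir="out", learning_rate=1e-4)

# The builder passes (optimizer_class, kwargs); Trainer constructs the
# optimizer lazily over the model's parameters instead of receiving an instance.
trainer = Trainer(
    model=model,
    args=args,
    optimizer_cls_and_kwargs=(torch.optim.AdamW, {"lr": 1e-4, "weight_decay": 0.0}),
)
```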

View File

@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
 import torch
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
+from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import Trainer
@@ -22,9 +23,11 @@ from transformers.utils import is_sagemaker_mp_enabled
 from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
 from trl.trainer.utils import pad_to_length

+from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
+    RexLR,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
     get_cosine_schedule_with_warmup_decay_constant,
@@ -115,6 +118,17 @@ class SchedulerMixin(Trainer):
                 **extra_lr_kwargs,
                 **self.args.lr_scheduler_kwargs,
             )
+        elif self.args.alternate_lr_scheduler_type == "rex":
+            if use_cosine_min_lr:
+                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
+            self.lr_scheduler = RexLR(
+                optimizer=optimizer,
+                max_lr=self.args.learning_rate,
+                min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
+                total_steps=num_training_steps,
+                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+            )
         elif use_cosine_quadratic:
             if use_cosine_min_lr:
                 LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -154,47 +168,18 @@ class SchedulerMixin(Trainer):
         return self.lr_scheduler

-class AxolotlTrainer(SchedulerMixin, Trainer):
+class OptimizerMixin(Trainer):
     """
-    Extend the base Trainer for axolotl helpers
+    Mixin class for shared handling of building custom optimizers
     """

     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
-    tag_names = ["axolotl"]

-    def __init__(
-        self,
-        *_args,
-        bench_data_collator=None,
-        eval_data_collator=None,
-        dataset_tags=None,
-        **kwargs,
-    ):
-        self.bench_data_collator = bench_data_collator
-        self.eval_data_collator = eval_data_collator
-        self.dataset_tags = dataset_tags
-        self._signature_columns = None  # workaround for pylint
-        super().__init__(*_args, **kwargs)
-        self.train_data_collator = self.data_collator
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
-        if self.args.orpo_alpha:
-            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
-
-    def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torch_compile:
-            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
-                256
-            )
-            model = torch.compile(
-                model,
-                backend=self.args.torch_compile_backend,
-                mode=self.args.torch_compile_mode,
-            )
-        return super()._wrap_model(model, training=training, dataloader=dataloader)
-
-    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
+    def create_optimizer_grouped_parameters(
+        self, opt_model, optimizer_kwargs
+    ) -> list[dict]:
         decay_parameters = self.get_decay_parameter_names(opt_model)
-        params = {
+        params: dict = {
             "to_weight_decay": {},  # LayerNorm and bias
             "embeddings": {},  # lm_head, embed_tokens,
             "no_weight_decay": {},
@@ -281,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
             and self.args.lr_groups is None
-            and self.args.alternate_optimizer
-            not in [
-                "optimi_adamw",
-                "ao_adamw_8bit",
-                "ao_adamw_4bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ]
+            and self.optimizer_cls_and_kwargs is None
         ):
             return super().create_optimizer()

         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
-            )
+        if (
+            not self.optimizer
+            and self.optimizer_cls_and_kwargs is not None
+            and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
+        ):
+            optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+            self.optimizer = optimizer_factory_cls()(
+                opt_model, self.args, **optimizer_kwargs
+            )
+
+        if not self.optimizer:
+            if self.optimizer_cls_and_kwargs is not None:
+                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+            else:
+                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
+                    self.args, opt_model
+                )
             optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
                 opt_model, optimizer_kwargs
             )
@@ -314,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                     loraplus_lr_embedding=loraplus_lr_embedding,
                     **optimizer_kwargs,
                 )
-            elif (
-                self.args.embedding_lr_scale is not None
-                or self.args.embedding_lr is not None
-                or self.args.lr_groups is not None
-            ):
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "optimi_adamw":
-                from optimi import AdamW
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW(
-                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
-                    )
-                )
-            elif self.args.alternate_optimizer == "ao_adamw_4bit":
-                from torchao.prototype.low_bit_optim import AdamW4bit
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "ao_adamw_8bit":
-                from torchao.prototype.low_bit_optim import AdamW8bit
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "ao_adamw_fp8":
-                from torchao.prototype.low_bit_optim import AdamWFp8
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "adopt_adamw":
-                from axolotl.utils.optimizers.adopt import ADOPT
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    ADOPT(
-                        optimizer_grouped_parameters,
-                        decouple=True,
-                        **optimizer_kwargs,
-                    )
-                )
+            else:
+                # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+                # e.g. for GaLore optimizer.
+                if "params" in optimizer_kwargs:
+                    optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+                # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+                # e.g. for LOMO optimizer.
+                if "model" in optimizer_kwargs:
+                    optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
+                # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+                # to avoid arguments conflicts.
+                if "optimizer_dict" in optimizer_kwargs:
+                    optimizer_grouped_parameters = optimizer_kwargs.pop(
+                        "optimizer_dict"
+                    )
+
+                self.optimizer = optimizer_cls(
+                    optimizer_grouped_parameters, **optimizer_kwargs
+                )
+
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    skipped = 0
+                    for module in opt_model.modules():
+                        if isinstance(module, nn.Embedding):
+                            skipped += sum(
+                                {
+                                    p.data_ptr(): p.numel() for p in module.parameters()
+                                }.values()
+                            )
+                            LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                            manager.register_module_override(
+                                module, "weight", {"optim_bits": 32}
+                            )
+                            LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
+                    LOG.info(f"skipped: {skipped/2**20}M params")

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -366,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         return self.optimizer

+class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
+    tag_names = ["axolotl"]
+
+    def __init__(
+        self,
+        *_args,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        dataset_tags=None,
+        **kwargs,
+    ):
+        self.bench_data_collator = bench_data_collator
+        self.eval_data_collator = eval_data_collator
+        self.dataset_tags = dataset_tags
+        self._signature_columns = None  # workaround for pylint
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if self.args.orpo_alpha:
+            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.torch_compile:
+            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
+                256
+            )
+            model = torch.compile(
+                model,
+                backend=self.args.torch_compile_backend,
+                mode=self.args.torch_compile_mode,
+            )
+        return super()._wrap_model(model, training=training, dataloader=dataloader)
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.args.sample_packing and not self.args.pretraining:
             if self.args.multipack_real_batches:
@@ -434,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
} }
if self.args.dataloader_prefetch_factor: if self.args.dataloader_prefetch_factor:
dataloader_params[ dataloader_params["prefetch_factor"] = (
"prefetch_factor" self.args.dataloader_prefetch_factor
] = self.args.dataloader_prefetch_factor )
sampler = self._get_train_sampler() sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler): if isinstance(sampler, BatchSampler):
@@ -481,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
} }
if self.args.dataloader_prefetch_factor: if self.args.dataloader_prefetch_factor:
dataloader_params[ dataloader_params["prefetch_factor"] = (
"prefetch_factor" self.args.dataloader_prefetch_factor
] = self.args.dataloader_prefetch_factor )
if isinstance(eval_sampler, BatchSampler): if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler dataloader_params["batch_sampler"] = eval_sampler
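The `BaseOptimizerFactory` branch above is called as `optimizer_factory_cls()(opt_model, self.args, **optimizer_kwargs)`, i.e. the factory is instantiated and then invoked to produce a ready optimizer. A minimal sketch of a conforming factory (illustrative only; the real base class lives in `axolotl.integrations.base` and is not shown in this diff):

```python
import torch


class SimpleAdamWFactory:
    """Hypothetical stand-in for a BaseOptimizerFactory subclass."""

    def __call__(self, opt_model, training_args, **optimizer_kwargs):
        # A real factory might build parameter groups from training_args
        # (e.g. separate embedding learning rates); this one does not.
        return torch.optim.AdamW(opt_model.parameters(), **optimizer_kwargs)


model = torch.nn.Linear(4, 2)
optimizer = SimpleAdamWFactory()(model, training_args=None, lr=1e-3)
```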

View File

@@ -1,6 +1,7 @@
""" """
DPO Specific Strategy for training DPO Specific Strategy for training
""" """
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
""" """
Axolotl specific DPO args Axolotl specific DPO args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import DPOConfig from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
""" """
DPO trainer for axolotl DPO trainer for axolotl
""" """
import gc import gc
from functools import wraps from functools import wraps
from typing import Any, Dict, Union from typing import Any, Dict, Union

View File

@@ -9,6 +9,7 @@ import logging
 from trl.trainer.grpo_trainer import RewardFunc

 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
+from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig

 LOG = logging.getLogger("axolotl")
@@ -31,30 +32,44 @@ class GRPOStrategy:
    @classmethod
    def set_training_args_kwargs(cls, cfg):
        grpo_args_kwargs = {}
-       if cfg.trl and cfg.trl.use_vllm:
-           grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
-       if cfg.trl and cfg.trl.vllm_device:
-           grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
-       else:
-           grpo_args_kwargs["vllm_device"] = "auto"
-       if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
-           grpo_args_kwargs[
-               "vllm_gpu_memory_utilization"
-           ] = cfg.trl.vllm_gpu_memory_utilization
-       if cfg.trl and cfg.trl.vllm_max_model_len:
-           grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
-       if cfg.trl and cfg.trl.num_generations:
-           grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
-       if cfg.trl and cfg.trl.sync_ref_model:
-           grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
-       if cfg.trl and cfg.trl.ref_model_mixup_alpha:
-           grpo_args_kwargs[
-               "ref_model_mixup_alpha"
-           ] = cfg.trl.ref_model_mixup_alpha
-       if cfg.trl and cfg.trl.ref_model_sync_steps:
-           grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
-       grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
-       grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
+       if not hasattr(cfg, "trl") or not cfg.trl:
+           return grpo_args_kwargs
+
+       trl: TRLConfig = cfg.trl  # type: ignore
+
+       if trl.use_vllm:
+           grpo_args_kwargs["use_vllm"] = trl.use_vllm
+           grpo_args_kwargs["vllm_device"] = (
+               trl.vllm_device if trl.vllm_device else "auto"
+           )
+       if trl.vllm_gpu_memory_utilization:
+           grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
+               trl.vllm_gpu_memory_utilization
+           )
+       if trl.vllm_max_model_len:
+           grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
+       if trl.num_generations:
+           grpo_args_kwargs["num_generations"] = trl.num_generations
+       if trl.sync_ref_model:
+           grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
+       if trl.ref_model_mixup_alpha:
+           grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
+       if trl.ref_model_sync_steps:
+           grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
+       grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
+       grpo_args_kwargs["log_completions"] = trl.log_completions
+       if trl.reward_weights:
+           grpo_args_kwargs["reward_weights"] = trl.reward_weights
        return grpo_args_kwargs
    @classmethod
@@ -71,9 +86,9 @@ class GRPOStrategy:
    def set_trainer_kwargs(cls, cfg):
        trainer_kwargs = {}
        if cfg.trl and cfg.trl.reward_processing_classes:
-           trainer_kwargs[
-               "reward_processing_classes"
-           ] = cfg.trl.reward_processing_classes
+           trainer_kwargs["reward_processing_classes"] = (
+               cfg.trl.reward_processing_classes
+           )
        return trainer_kwargs
    @classmethod

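The rewrite above trades repeated `cfg.trl and cfg.trl.<field>` guards for one early return plus a typed local, which is what lets the remaining checks read flat. The shape of that pattern in isolation, with `SimpleNamespace` standing in for the real pydantic config object:

```python
from types import SimpleNamespace


def build_kwargs(cfg):
    kwargs = {}
    # Bail out once instead of re-checking cfg.trl in every condition.
    if not getattr(cfg, "trl", None):
        return kwargs

    trl = cfg.trl
    if trl.use_vllm:
        kwargs["use_vllm"] = trl.use_vllm
        # Fold the old if/else into a single fallback expression.
        kwargs["vllm_device"] = trl.vllm_device or "auto"
    return kwargs


cfg = SimpleNamespace(trl=SimpleNamespace(use_vllm=True, vllm_device=None))
print(build_kwargs(cfg))  # {'use_vllm': True, 'vllm_device': 'auto'}
```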
View File

@@ -1,6 +1,7 @@
""" """
Axolotl Specific Training Args Axolotl Specific Training Args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import GRPOConfig from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
""" """
Axolotl GRPO trainer Axolotl GRPO trainer
""" """
from accelerate.utils import is_peft_model from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel from transformers import PreTrainedModel

View File

@@ -1,6 +1,7 @@
""" """
module for TRL PPO training module for TRL PPO training
""" """
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from trl import PPOTrainer from trl import PPOTrainer

View File

@@ -1,6 +1,7 @@
""" """
extra axolotl specific training args extra axolotl specific training args
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional

View File

@@ -10,7 +10,6 @@ import torch
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
-from axolotl.telemetry.errors import send_errors
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
@@ -62,7 +61,6 @@ def evaluate_dataset(
    return metrics
-@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
    """
    Evaluate a model on training and validation datasets

View File

@@ -23,6 +23,8 @@ import importlib
import logging
from typing import OrderedDict
+import torch
class BasePlugin:
    """
@@ -469,3 +471,14 @@ class PluginManager:
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
plugin.post_train_unload(cfg) plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
    """
    Base class for factories to create custom optimizers
    """

    def __call__(
        self, opt_model, training_args, **optimizer_kwargs
    ) -> "torch.optim.Optimizer":
        pass

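`BaseOptimizerFactory` fixes the calling convention for plugin-supplied optimizers: a callable taking the model and training args and returning a `torch.optim.Optimizer`. A hedged sketch of a concrete factory; `AdamWFactory` is invented for illustration, and the `learning_rate` lookup is an assumption about what `training_args` carries:

```python
import torch
from torch import nn


class AdamWFactory:
    """Hypothetical factory; only the __call__ shape comes from BaseOptimizerFactory."""

    def __call__(self, opt_model, training_args, **optimizer_kwargs) -> torch.optim.Optimizer:
        # Assumed: training_args may expose a learning_rate; explicit kwargs win.
        lr = optimizer_kwargs.pop("lr", getattr(training_args, "learning_rate", 1e-4))
        return torch.optim.AdamW(opt_model.parameters(), lr=lr, **optimizer_kwargs)


model = nn.Linear(4, 4)
optimizer = AdamWFactory()(model, training_args=None, weight_decay=0.01)
print(type(optimizer).__name__)  # AdamW
```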
View File

@@ -4,6 +4,22 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
```yaml

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
    "Please install cut_cross_entropy with transformers support using "
-   '`pip install "cut-cross-entropy[transformers]==24.11.4"`'
+   '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
)

View File

@@ -1,6 +1,7 @@
""" """
Grokfast plugin for Axolotl Grokfast plugin for Axolotl
""" """
import logging import logging
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
""" """
config args for grokfast plugin config args for grokfast plugin
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
""" """
kd_trainer: Optional[bool] = None # whether to use KD trainer kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[ kd_ce_alpha: Optional[float] = (
float None # loss coefficient for cross-entropy loss during KD
] = None # loss coefficient for cross-entropy loss during KD )
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[ kd_top_k_before_softmax: Optional[bool] = (
bool None # whether to sample top k before softmax during KD
] = None # whether to sample top k before softmax during KD )

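These fields map onto the usual knowledge-distillation objective: `kd_ce_alpha` weights the hard-label cross-entropy, `kd_alpha` weights the softened teacher/student term, and `kd_temperature` smooths both distributions. A generic sketch of that blend, standard KD math (Hinton et al., 2015) rather than the KD trainer's exact loss:

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, labels,
            kd_ce_alpha=1.0, kd_alpha=1.0, kd_temperature=2.0):
    # Hard-label term: plain cross-entropy against ground truth.
    ce = F.cross_entropy(student_logits, labels)
    # Soft-label term: KL between temperature-scaled distributions,
    # rescaled by T^2 so gradients stay comparable across temperatures.
    t = kd_temperature
    kld = F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)
    return kd_ce_alpha * ce + kd_alpha * kld


student, teacher = torch.randn(4, 10), torch.randn(4, 10)
print(kd_loss(student, teacher, torch.randint(0, 10, (4,))).item())
```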
View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
if "cross_entropy" in liger_fn_sig.parameters: if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters: if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[ kwargs["fused_linear_cross_entropy"] = (
"fused_linear_cross_entropy" cfg.liger_fused_linear_cross_entropy
] = cfg.liger_fused_linear_cross_entropy )
if "rms_norm" in liger_fn_sig.parameters: if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters: if "layer_norm" in liger_fn_sig.parameters:

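The `in liger_fn_sig.parameters` guards exist because different `liger_kernel` releases accept different keyword arguments, so the plugin forwards only the kwargs the installed patch function declares. The same defensive pattern in isolation; `apply_liger_kernel` below is a stand-in, not the real entry point:

```python
import inspect


def apply_liger_kernel(cross_entropy=False, rms_norm=False):
    """Stand-in for an older patch fn that lacks newer kwargs."""
    return {"cross_entropy": cross_entropy, "rms_norm": rms_norm}


desired = {
    "cross_entropy": True,
    "rms_norm": True,
    "fused_linear_cross_entropy": True,  # not supported by this version
}

# Forward only the kwargs this installed version actually declares.
sig = inspect.signature(apply_liger_kernel)
kwargs = {key: val for key, val in desired.items() if key in sig.parameters}
print(apply_liger_kernel(**kwargs))  # the unsupported kwarg is dropped
```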
View File

@@ -1,6 +1,7 @@
""" """
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
Jamba model with LigerFusedLinearCrossEntropyLoss Jamba model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
Module for the Plugin for LM Eval Harness Module for the Plugin for LM Eval Harness
""" """
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
""" """
Module for handling lm eval harness input arguments. Module for handling lm eval harness input arguments.
""" """
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
""" """
axolotl CLI for running lm_eval tasks axolotl CLI for running lm_eval tasks
""" """
import subprocess # nosec import subprocess # nosec
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, model_validator
class SpectrumArgs(BaseModel): class SpectrumArgs(BaseModel):
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
    spectrum_top_fraction: Optional[float] = 0.5
    spectrum_model_name: Optional[str] = None
    @model_validator(mode="before")
    @classmethod
    def check_fsdp_use_orig_params(cls, data):
        if (
            data.get("fsdp")
            and data.get("fsdp_config")
            and not data["fsdp_config"].get("use_orig_params")
            and data.get("plugins")
            and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
        ):
            # would otherwise raise
            # ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
            raise ValueError(
                "FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
            )
        return data

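Because the validator runs with `mode="before"`, it receives the raw config dict and can cross-check keys (`fsdp`, `plugins`) that are not fields of `SpectrumArgs` itself. A small self-contained demonstration of the same guard shape, on a standalone model rather than the real schema:

```python
from pydantic import BaseModel, ValidationError, model_validator


class GuardedArgs(BaseModel):
    spectrum_top_fraction: float = 0.5

    @model_validator(mode="before")
    @classmethod
    def check_fsdp_use_orig_params(cls, data):
        # Raw-dict check: reject FSDP + SpectrumPlugin without use_orig_params.
        if (
            data.get("fsdp")
            and data.get("fsdp_config")
            and not data["fsdp_config"].get("use_orig_params")
            and any("SpectrumPlugin" in p for p in data.get("plugins", []))
        ):
            raise ValueError("FSDP + SpectrumPlugin requires use_orig_params=True")
        return data


try:
    GuardedArgs.model_validate(
        {"fsdp": ["full_shard"], "fsdp_config": {}, "plugins": ["x.SpectrumPlugin"]}
    )
except ValidationError as err:
    print(err.errors()[0]["msg"])  # Value error, FSDP + SpectrumPlugin requires ...
```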
View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable

View File

@@ -1,4 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration.""" """Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement # pylint: disable=invalid-name,global-statement
import ctypes import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl

View File

@@ -1,6 +1,7 @@
""" """
HF Transformers MambaConfig HF Transformers MambaConfig
""" """
from transformers import PretrainedConfig from transformers import PretrainedConfig

View File

@@ -1,6 +1,7 @@
""" """
Monkeypatch for Vision Llama for FA2 support Monkeypatch for Vision Llama for FA2 support
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -220,10 +221,10 @@ def patch_mllama():
            True
        )
    MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
-   MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
-       "flash_attention_2"
-   ] = MllamaTextCrossFlashAttention2
+   MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
+       MllamaTextCrossFlashAttention2
+   )
    # fallback to SDPA
-   MLLAMA_VISION_ATTENTION_CLASSES[
-       "flash_attention_2"
-   ] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
+   MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
+       MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
+   )

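The patch works because these transformers model modules dispatch attention through per-model dicts keyed by the configured attention implementation; registering a class under `"flash_attention_2"`, or aliasing it to the SDPA entry as the vision fallback does, changes which class gets instantiated. The dict mechanics in miniature, with toy classes rather than the real mllama ones:

```python
class SdpaAttention:
    """Toy stand-in for an SDPA attention implementation."""


class FlashAttention2:
    """Toy stand-in for a flash-attention-2 implementation."""


# Per-model registries, keyed like the *_ATTENTION_CLASSES dicts above.
TEXT_CLASSES = {"eager": SdpaAttention, "sdpa": SdpaAttention}
VISION_CLASSES = {"sdpa": SdpaAttention}

# Register a real FA2 class for text attention ...
TEXT_CLASSES["flash_attention_2"] = FlashAttention2
# ... and alias FA2 to SDPA where no FA2 kernel exists (the vision fallback).
VISION_CLASSES["flash_attention_2"] = VISION_CLASSES["sdpa"]

print(TEXT_CLASSES["flash_attention_2"].__name__)    # FlashAttention2
print(VISION_CLASSES["flash_attention_2"].__name__)  # SdpaAttention
```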
View File

@@ -1,4 +1,5 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes""" """monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access # pylint: disable=protected-access
import torch import torch

View File

@@ -12,7 +12,9 @@ import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+)
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
@@ -490,9 +492,11 @@ def flashattn_forward(
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
-           query_padding_mask=attention_mask[:, -query_states.size(1) :]
-           if attention_mask is not None
-           else None,
+           query_padding_mask=(
+               attention_mask[:, -query_states.size(1) :]
+               if attention_mask is not None
+               else None
+           ),
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
@@ -531,9 +535,11 @@ def flashattn_forward(
            value_states,
            kvpacked=True,
            key_padding_mask=attention_mask,
-           query_padding_mask=attention_mask[:, -query_states.size(1) :]
-           if attention_mask is not None
-           else None,
+           query_padding_mask=(
+               attention_mask[:, -query_states.size(1) :]
+               if attention_mask is not None
+               else None
+           ),
        )
        if q_unpad.dtype != kv_unpad.dtype:
            kv_unpad = kv_unpad.to(q_unpad.dtype)

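The reformatted conditional slices the padding mask down to the query length, which matters when `query_states` covers only the most recent tokens (as with a KV cache) while `attention_mask` spans the whole sequence. The slicing in isolation:

```python
import torch

# Full-sequence mask: batch of 2, sequence length 6 (0 = padded position).
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [0, 0, 1, 1, 1, 1]])

# During incremental decoding the queries may cover only the last q_len tokens.
q_len = 2
query_padding_mask = (
    attention_mask[:, -q_len:] if attention_mask is not None else None
)
print(query_padding_mask)  # tensor([[1, 1], [1, 1]])
```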
View File

@@ -1,6 +1,7 @@
""" """
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
""" """
from typing import Optional from typing import Optional
import torch import torch

View File

@@ -1,4 +1,5 @@
"""Flash attention monkey patch for mistral model""" """Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging import logging
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OriginalMistralDecoderLayer,
)
-from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
+from transformers.models.mistral.modeling_mistral import (
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -243,9 +247,11 @@ def flashattn_forward(
            # We have disabled _prepare_decoder_attention_mask in LlamaModel
            # the attention_mask should be the same as the key_padding_mask
            key_padding_mask=attention_mask,
-           query_padding_mask=attention_mask[:, -query_states.size(1) :]
-           if attention_mask is not None
-           else None,
+           query_padding_mask=(
+               attention_mask[:, -query_states.size(1) :]
+               if attention_mask is not None
+               else None
+           ),
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
@@ -286,9 +292,11 @@ def flashattn_forward(
            value_states,
            kvpacked=True,
            key_padding_mask=attention_mask,
-           query_padding_mask=attention_mask[:, -query_states.size(1) :]
-           if attention_mask is not None
-           else None,
+           query_padding_mask=(
+               attention_mask[:, -query_states.size(1) :]
+               if attention_mask is not None
+               else None
+           ),
        )
        if q_unpad.dtype != kv_unpad.dtype:
            kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
""" """
Patches to support multipack for mixtral Patches to support multipack for mixtral
""" """
import torch import torch

View File

@@ -1,4 +1,5 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.""" """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob import glob
import json import json
import logging import logging
@@ -411,7 +412,10 @@ def merge_and_save(
        if shard_path.endswith(".safetensors"):
            in_tensors = st.load_file(str(Path(model_src) / shard_path))
        else:
-           in_tensors = torch.load(Path(model_src) / shard_path)
+           in_tensors = torch.load(
+               Path(model_src) / shard_path,
+               weights_only=True,  # to prevent arbitrary code execution
+           )
        if "state_dict" in in_tensors:
            in_tensors = in_tensors["state_dict"]

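Passing `weights_only=True` makes `torch.load` restrict unpickling to tensors and plain containers, closing the arbitrary-code-execution hole that loading an untrusted checkpoint otherwise opens. A self-contained round trip:

```python
import tempfile

import torch

with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
    torch.save({"state_dict": {"w": torch.zeros(2, 2)}}, tmp.name)
    # weights_only=True refuses to unpickle arbitrary Python objects,
    # so a crafted checkpoint cannot execute code at load time.
    loaded = torch.load(tmp.name, weights_only=True)

print(loaded["state_dict"]["w"].shape)  # torch.Size([2, 2])
```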
View File

@@ -17,7 +17,7 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
-""" PyTorch StableLM Epoch model. """
+"""PyTorch StableLM Epoch model."""
import importlib
import math
from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
""" """
fix for FSDP optimizer save in trainer w 4.47.0 fix for FSDP optimizer save in trainer w 4.47.0
""" """
import inspect import inspect
import logging import logging

View File

@@ -1,6 +1,7 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@@ -1,6 +1,7 @@
""" """
Fused MLP layer for incrementally improved training efficiency Fused MLP layer for incrementally improved training efficiency
""" """
import torch import torch
from transformers.models.llama.modeling_llama import LlamaMLP from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU from xformers.ops import SwiGLU

View File

@@ -1,6 +1,7 @@
""" """
Prompt strategies loader for alpaca instruction datasets with system prompts Prompt strategies loader for alpaca instruction datasets with system prompts
""" """
from typing import Generator, Tuple, Union from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
""" """
Basic completion text Basic completion text
""" """
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple

View File

@@ -1,4 +1,5 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies""" """Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
""" """
module for DPO style dataset transform strategies module for DPO style dataset transform strategies
""" """
from functools import partial from functools import partial
from ..base import load as load_base from ..base import load as load_base

View File

@@ -33,9 +33,9 @@ def default(
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>" sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>" sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample return sample
@@ -52,9 +52,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample return sample
@@ -78,9 +78,9 @@ def icr(
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen']}<|im_end|>" sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>" sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample return sample

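Each loader above produces a `transform_fn` that rewrites a preference-dataset row into the `prompt`/`chosen`/`rejected` strings a DPO trainer expects, with the chat-template control tokens baked in. A stripped-down version of the ChatML `default` transform, free of the loader plumbing:

```python
def chatml_transform(sample, prompt_key="prompt", chosen_key="chosen", rejected_key="rejected"):
    # Wrap the user turn in ChatML markers and open the assistant turn ...
    sample["prompt"] = (
        f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
    )
    # ... then close each candidate completion with the end-of-turn token.
    sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
    sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
    return sample


row = {"prompt": "2+2?", "chosen": "4", "rejected": "5"}
print(chatml_transform(dict(row))["prompt"])
```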
View File

@@ -34,9 +34,9 @@ def default(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>" sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>" sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample return sample
@@ -53,9 +53,9 @@ def argilla_chat(
""" """
def transform_fn(sample): def transform_fn(sample):
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample return sample
@@ -79,9 +79,9 @@ def icr(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen']}<|eot_id|>" sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>" sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
) )
else: else:
sample[ sample["prompt"] = (
"prompt" f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" )
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>" sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>" sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample return sample

Some files were not shown because too many files have changed in this diff.