Compare commits

..

40 Commits

Author SHA1 Message Date
Dan Saunders
c649d569b4 simplify by installing no deps 2025-03-21 13:27:54 -04:00
Dan Saunders
b88b389b17 installing axolotl prior to quartodoc build 2025-03-21 16:52:51 +00:00
Dan Saunders
23f0c51d88 Sequence parallelism (#2412)
* adding easy_context as integration for now

* progress on ring attn impl

* progress on ring attn impl

* cleanup

* remove errant file

* fix req

* removing unused code

* updates

* pytest

* update

* updates

* fixes

* precommit fixes

* working multi-group SP

* fixing sample packing

* remove debug logs and simplify

* eval dataloader and sampler changes

* removing some obvious comments

* update config.qmd and rename option

* scoping down problematic import

* another import scoping change

* pernicious Fire CLI bugfix

* isolate cli tests

* actually isolate CLI tests

* gracefully handle no ring-flash-attn

* fix

* fix

* move ring flash attn to extras with flash-attn (#2414)

* removing flash-attn from requirements.txt (in setup.py extras already)

* rename file, delete another

* using field validator instead of model validator

* test fix

* sampler / dataloader refactor

* non-seq2se1 collator fix

* removing print statement

* bugfix

* add SP doc, review comments

* small changes

* review comments, docstrings

* refactors, SP mixin

* small updates

* fix tests

* precommit

* precommit

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-03-21 12:43:55 -04:00
Dan Saunders
113e9cd193 Autodoc generation with quartodoc (#2419)
* quartodoc integration

* quartodoc progress

* deletions

* Update docs/.gitignore to exclude auto-generated API documentation files

* Fix

* more autodoc progress

* moving reference up near the top of the sidebar

* fix broken link

* update to reflect recent changes

* pydantic models refactor + add to autodoc + fixes

* fix

* shrinking header sizes

* fix accidental change

* include quartodoc build step

* update pre-commit version

* update pylint

* pre-commit

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-03-21 12:26:47 -04:00
NanoCode012
61825a464a chore(doc): add explanation on fsdp_transformer_layer_cls_to_wrap (#2429) [skip ci] 2025-03-21 11:59:22 -04:00
Dan Saunders
c907ac173e adding pre-commit auto-update GH action and bumping plugin versions (#2428)
* adding pre-commit auto-update GH action and bumping plugin versions

* running updated pre-commit plugins

* sorry to revert, but pylint complained

* Update .pre-commit-config.yaml

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 11:02:43 -04:00
salman
187227d837 Fixing KTO+QLoRA+multi-GPU (#2420)
* WIP

* removing artifacts

* adding error

* adding adapter check

* linting

* simplifying check

* linting v2

* config fix -___-
2025-03-21 10:18:28 -04:00
NanoCode012
f8de8bb4f2 chore(doc): add instructions on adding custom integrations (#2422) [skip ci]
* chore(doc): add instructions on adding custom integrations

* chore: add warning help

* feat: add note about integration path

* fix: adjust text per suggestion
2025-03-21 10:18:01 -04:00
hugo
8e604848a4 add run on novita ai (#2421) [skip ci]
* add run on novita ai

* Revert "add run on novita ai"

This reverts commit 4d5df1ac6b.

* add run axolotl on novita ai
2025-03-21 10:17:47 -04:00
Wing Lian
aae4337f40 add 12.8.1 cuda to the base matrix (#2426)
* add 12.8.1 cuda to the base matrix

* use nightly

* bump deepspeed and set no binary

* deepspeed binary fixes hopefully

* install deepspeed by itself

* multiline fix

* make sure ninja is installed

* try with reversion of packaging/setuptools/wheel install

* use license instead of license-file

* try rolling back packaging and setuptools versions

* comment out license for validation for now

* make sure packaging version is consistent

* more parity across tests and docker images for packaging/setuptools
2025-03-21 10:17:25 -04:00
Wing Lian
38df5a36ea bump HF versions except for trl (#2427) 2025-03-20 10:22:05 -04:00
Wing Lian
4d92a68a96 use default torch fused adamw optimizer as default as adamw_hf is deprecated (#2425)
* use default torch fused adamw optimizer as default as adamw_hf is deprecated

* make sure to have latest packaging installed

* bump packagingin requirements.txt too
2025-03-19 23:58:33 -04:00
SicariusSicariiStuff
85147ec430 Update README.md (#2360)
* Update README.md

wheel is needed

* feat: add ninja, setuptools, packing to installation steps

* fix: add missing instruction

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-03-17 08:39:17 -04:00
NanoCode012
51cd409488 Feat: minor docs improvements for RLHF and faq on embeddings (#2401) [skip ci]
* feat: add doc on shrink_embeddings and custom calling

* chore: rename inference doc

* fix: clarify same config is used for all cli

* chore: rearrange order inference qmd

* feat: add simpo to doc

* fix: update defaults

* feat: add rl configs to doc

* fix: ensure beta consistent with trl.beta

* fix: clarify about lora/fft

* chore: rename title

* chore: fix language

* feat: move config reference higher

* Update docs/getting-started.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/rlhf.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-17 08:39:04 -04:00
NanoCode012
7235123d44 chore(docs): add cookbook/blog link to docs (#2410) [skip ci] 2025-03-17 08:38:19 -04:00
Wing Lian
4f5eb42a73 remove reference to deprecated import (#2407) 2025-03-15 08:49:41 -04:00
Wing Lian
fbe54be6b8 only validate hf user token on rank 0 (#2408) 2025-03-13 23:29:06 -04:00
Wing Lian
04f6324833 build cloud images with torch 2.6.0 (#2413)
* build cloud images with torch 2.6.0

* nightlies too
2025-03-13 23:28:51 -04:00
Wing Lian
f0072f3b9d use max of 32 dataset processes if not explicit (#2403)
* use max of 32 dataset processes if not explicit

* change alternate min val for consistency
2025-03-11 12:02:58 -04:00
Wing Lian
59899b9817 pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
2025-03-11 12:02:43 -04:00
NanoCode012
4a736986fa fix(modal): add git pull when getting branch files (#2399) 2025-03-10 15:14:41 -04:00
Wing Lian
5d0f110a3b include iproute2 and nvtop in cloud image (#2393) 2025-03-10 15:13:38 -04:00
NanoCode012
83f8698b8a fix: create mount folder on modal if not exist (#2390) 2025-03-10 16:27:42 +07:00
xzuyn
60a11a6410 Use Latest Cut Cross Entropy (#2392)
* Update __init__.py

* Update README.md

* Update cutcrossentropy_install.py

* add test
2025-03-10 16:26:40 +07:00
NanoCode012
46a045e528 chore(doc): add faq when having no default chat_template (#2398)
* chore(doc): add faq when having no default chat_template

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/faq.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:50 +07:00
NanoCode012
3b477e08a0 feat(doc): add more info on RewardModel datasets (#2391)
* fix: reduce title size

* feat(doc): add rm dataset info

* Update docs/reward_modelling.qmd following suggestion

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-10 16:25:31 +07:00
NanoCode012
16dc6ee68d refactor: trl grpo configs to have descriptions (#2386)
* refactor: trl grpo configs to have descriptions

* chore: caps
2025-03-07 08:58:53 -05:00
Wing Lian
fa7c79b3b9 remove lion-pytorch as it's already handled upstream (#2389) 2025-03-07 08:58:15 -05:00
Wing Lian
ae66374156 Optimizer refactor and add Muon support (#2367)
* add muon optimizer

optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer

* fix import
2025-03-06 11:49:19 -05:00
Wing Lian
5e21b1a9da various fixes 20250305 (#2384)
* various validation fixes

* fix check for non-truthy value
2025-03-06 11:48:44 -05:00
mhenrichsen
575e5f28ec Update Tokenizer Overrides Handling in models.py (#1549)
* override special tokens mock code

* fix(doc): remove duplicate config

* feat: replace added_tokens in tokenizer and add test

* make sure to run tokenizer modification on rank 0 only

* use is local main process instead

* feat: rename config

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-05 11:15:12 -05:00
xzuyn
0134093acc Add REX LR Scheduler (#2380)
* Update trainer_builder.py

* Update base.py

* Update __init__.py

* Update base.py

* Update base.py

* Update config.qmd

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* lint

* lint

* lint

* lint

* lint

* lint

* Update base.py

* Update base.py

* lint

* Update base.py

* Update base.py

* Move RexLR to `schedulers.py`

* Remove RexLR from `base.py`

* Fix tooltip formatting

* lint

* Create test_schedulers.py

* Use a default optimizer in test

* lint

* lint

* Add `warmup_steps` and `cosine_min_lr_ratio` to test

* lint
2025-03-05 10:26:11 -05:00
NanoCode012
d4de93a7bb feat(grpo): add reward_weights config and refactor (#2365) 2025-03-05 10:02:08 -05:00
NanoCode012
c8191394e9 fix(doc): add missing low_cpu_mem_usage config to docs (#2369) [skip ci] 2025-03-05 10:01:44 -05:00
NanoCode012
f18231c653 chore(doc): add clarification about mpi4py error on single gpu deepspeed (#2383) [skip ci]
* chore(doc): add clarification about mpi4py error on single gpu deepspeed

* fix: lint
2025-03-05 10:01:28 -05:00
NanoCode012
9ed4f6b3aa feat(doc): document drop_system_message and clarify limitation (#2381) [skip ci] 2025-03-05 10:01:16 -05:00
NanoCode012
05dddfc41d feat(doc): add docker images explanation (#2379) [skip ci]
* feat(doc): add docker images explanation

* chore: add link to dockerhub
2025-03-05 10:01:00 -05:00
NanoCode012
8e30917440 chore(docs): remove phorm (#2378) [skip ci] 2025-03-05 10:00:50 -05:00
NanoCode012
d883b11b6f fix(doc): add installation for cce to docs (#2375) [skip ci]
* fix(doc): add installation for cce to docs

* fix: format
2025-03-05 10:00:39 -05:00
Dan Saunders
f4910dd2ea train.py refactor (#2371)
* refactor train.py

* updates

* update

* combine like functions

* review comments
2025-03-05 08:58:33 -05:00
235 changed files with 4932 additions and 4545 deletions

View File

@@ -40,6 +40,12 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -61,7 +67,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/Dockerfile-base
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,9 +20,12 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: install dependencies
- name: Install dependencies
run: |
python3 -m pip install jupyter
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
- name: Build autodoc
run: quartodoc build
- name: Publish to GitHub Pages (and render)
uses: quarto-dev/quarto-actions/publish@v2
with:

View File

@@ -88,6 +88,11 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,6 +80,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -0,0 +1,49 @@
name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff
jobs:
auto-update:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Update pre-commit hooks
id: update
run: |
pip install pre-commit
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
if: steps.update.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
branch: update/pre-commit-hooks
delete-branch: true
title: "chore: update pre-commit hooks"
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -59,7 +59,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -98,8 +98,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -147,7 +148,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |
@@ -172,8 +173,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |

4
.gitignore vendored
View File

@@ -181,6 +181,10 @@ prepared-datasets/
submit.sh
*.out*
# Quartodoc generated files
objects.json
site_libs/
typings/
out/

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 7.1.2
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v3.3.0
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.6
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
rev: 1.8.3
hooks:
- id: bandit
args: [

View File

@@ -19,9 +19,6 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.
@@ -58,6 +55,7 @@ Features:
### Installation
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
@@ -99,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
## 🤝 Getting Help

View File

@@ -1,6 +1,178 @@
project:
type: website
quartodoc:
dir: docs/api
package: axolotl
title: API Reference
parser: google
sections:
- title: Core
desc: Core functionality for training
contents:
- train
- evaluate
- datasets
- convert
- prompt_tokenizers
- logging_config
- core.trainer_builder
- core.training_args
- core.chat.messages
- core.chat.format.chatml
- core.chat.format.llama3x
- core.chat.format.shared
- core.datasets.chat
- core.datasets.transforms.chat_builder
- title: CLI
desc: Command-line interface
contents:
- cli.main
- cli.train
- cli.evaluate
- cli.args
- cli.checks
- cli.config
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- title: Prompt Strategies
desc: Prompt formatting strategies
contents:
- prompt_strategies.base
- prompt_strategies.chat_template
- prompt_strategies.alpaca_chat
- prompt_strategies.alpaca_instruct
- prompt_strategies.alpaca_w_system
- prompt_strategies.user_defined
- prompt_strategies.llama2_chat
- prompt_strategies.completion
- prompt_strategies.input_output
- prompt_strategies.stepwise_supervised
- prompt_strategies.metharme
- prompt_strategies.orcamini
- prompt_strategies.pygmalion
- prompt_strategies.messages.chat
- prompt_strategies.dpo.chat_template
- prompt_strategies.dpo.llama3
- prompt_strategies.dpo.chatml
- prompt_strategies.dpo.zephyr
- prompt_strategies.dpo.user_defined
- prompt_strategies.dpo.passthrough
- prompt_strategies.kto.llama3
- prompt_strategies.kto.chatml
- prompt_strategies.kto.user_defined
- prompt_strategies.orpo.chat_template
- prompt_strategies.bradley_terry.llama3
- title: Kernels
desc: Low-level performance optimizations
contents:
- kernels.lora
- kernels.geglu
- kernels.swiglu
- kernels.quantize
- kernels.utils
- title: MonkeyPatches
desc: Runtime patches for model optimizations
contents:
- monkeypatch.llama_attn_hijack_flash
- monkeypatch.llama_attn_hijack_xformers
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- title: Utils
desc: Utility functions
contents:
- utils.models
- utils.tokenization
- utils.chat_templates
- utils.lora
- utils.lora_embeddings
- utils.model_shard_quant
- utils.bench
- utils.freeze
- utils.trainer
- utils.schedulers
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
- utils.schemas.config
- utils.schemas.model
- utils.schemas.training
- utils.schemas.datasets
- utils.schemas.peft
- utils.schemas.trl
- utils.schemas.integrations
- utils.schemas.enums
- utils.schemas.utils
- title: Integrations
desc: Third-party integrations and extensions
contents:
- integrations.base
- integrations.cut_cross_entropy.args
- integrations.grokfast.optimizer
- integrations.kd.trainer
- integrations.liger.args
- integrations.lm_eval.args
- integrations.spectrum.args
- title: Common
desc: Common utilities and shared functionality
contents:
- common.architectures
- common.const
- common.datasets
- title: Models
desc: Custom model implementations
contents:
- models.mamba.modeling_mamba
- title: Data Processing
desc: Data processing utilities
contents:
- utils.collators.core
- utils.collators.batching
- utils.collators.mamba
- utils.collators.mm_chat
- utils.samplers.multipack
- title: Callbacks
desc: Training callbacks
contents:
- utils.callbacks.perplexity
- utils.callbacks.profiler
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
@@ -32,14 +204,18 @@ website:
contents:
- docs/getting-started.qmd
- docs/installation.qmd
- docs/cli.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- text: "API Reference"
href: docs/api
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Deployments"
contents:
- docs/docker.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/ray-integration.qmd
@@ -73,12 +249,27 @@ website:
- docs/debugging.qmd
- docs/nccl.qmd
- section: "Reference"
contents:
- docs/config.qmd
format:
html:
theme: darkly
css: styles.css
toc: true
# Enable better handling of line breaks in markdown
preserve-tabs: true
html-math-method: mathjax
# Improved markdown processing options
md-extensions:
- markdown_it
- def_list
- attr_list
- fenced_divs
- tables
- html_admonition
- lineblocks
- fancy_lists
# Control whitespace handling
whitespace: preserve
# Process newlines in paragraphs
wrap: preserve
# Better line break handling
preserve-linebreaks: true

View File

@@ -31,10 +31,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -3,9 +3,10 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 /workspace/axolotl/tests/cli
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,7 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os

View File

@@ -1,4 +1,5 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

2
docs/.gitignore vendored
View File

@@ -1,2 +1,4 @@
/.quarto/
_site/
/api/*.qmd
/api/*.html

View File

@@ -1,5 +1,5 @@
---
title: "CLI Reference"
title: "Command Line Interface (CLI)"
format:
html:
toc: true

View File

@@ -1,5 +1,5 @@
---
title: Config options
title: Config Reference
description: A complete list of all configuration options.
---
@@ -30,6 +30,11 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
# (Internal use only)
# Used to identify which the model is based on
@@ -83,6 +88,12 @@ gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true
# List[str]. Add plugins to extend the pipeline.
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -154,8 +165,6 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -163,6 +172,12 @@ datasets:
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
@@ -201,10 +216,46 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto'
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
rl_beta: # Optional[float]. The beta parameter for the RL training.
# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
# reward modelling: `True` or `False`
reward_model:
@@ -222,13 +273,13 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # repo path
push_dataset_to_hub: # Optional[str] repo_org/repo_name
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
@@ -445,7 +496,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -528,6 +579,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
@@ -548,6 +601,13 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:
@@ -560,6 +620,14 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -55,3 +55,47 @@ sections = [
for section_name, folder_name in sections:
print(print_section(section_name, folder_name))
```
## Adding a new integration
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
To add a new integration, please follow these steps:
1. Create a new folder in the `src/axolotl/integrations` directory.
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
3. Add `__init__.py` and `args.py` files to the new folder.
- `__init__.py` should import the integration and hook into the appropriate functions.
- `args.py` should define the arguments for the integration.
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
::: {.callout-tip}
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
:::
::: {.callout-warning}
If you could not load your integration, please ensure you are pip installing in editable mode.
```bash
pip install -e .
```
and correctly spelled the integration name in the config file.
```yaml
plugins:
- axolotl.integrations.your_integration_name.YourIntegrationPlugin
```
:::
::: {.callout-note}
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
:::

View File

@@ -74,6 +74,10 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml

View File

@@ -129,6 +129,7 @@ You can mix and match within each approach or across approaches to train a model
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.

View File

@@ -6,7 +6,7 @@ description: How datasets are processed
## Overview
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](docs/dataset-formats) and prompt strategies to:
the [dataset format](dataset-formats) and prompt strategies to:
- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*

140
docs/docker.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -19,12 +19,28 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
**Q: The codes is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -50,3 +66,7 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
load_in_8bit: true
adapter: lora
datasets:
- path: teknium/GPT4-LLM-Cleaned
@@ -44,11 +46,15 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
::: {.callout-tip}
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
- To perform Full finetuning, remove these two lines.
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
:::
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies LoRA adapter layers
2. (If specified) applies QLoRA/LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
load_in_8bit: true
adapter: lora
# Training settings
@@ -104,8 +112,6 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```bash

View File

@@ -1,5 +1,5 @@
---
title: "Inference"
title: "Inference and Merging"
format:
html:
toc: true
@@ -9,10 +9,14 @@ execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
::: {.callout-tip}
Use the same config used for training on inference/merging.
:::
### Basic Inference {#sec-basic}
::: {.panel-tabset}

View File

@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
@@ -37,7 +38,7 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
@@ -65,6 +66,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}
@@ -76,6 +79,7 @@ For providers supporting Docker:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Novita](https://novita.ai/gpus-console?templateId=311)
### Google Colab {#sec-colab}
@@ -105,7 +109,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:

View File

@@ -66,6 +66,10 @@ logic to be compatible with more of them.
</details>
::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The

View File

@@ -28,8 +28,23 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
::: {.callout-tip}
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
:::
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B
@@ -45,3 +60,5 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
@@ -297,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### IPO
As IPO is just DPO with a different loss function, all supported options for DPO works here.
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
```yaml
rl: ipo
@@ -343,8 +344,9 @@ ORPO supports the following types with the following dataset format:
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
remove_unused_columns: false
@@ -496,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
@@ -528,6 +534,7 @@ trl:
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
@@ -536,6 +543,21 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```
This method uses the same dataset format as [DPO](#dpo).
### Using local dataset files
```yaml

View File

@@ -0,0 +1,90 @@
---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---
# Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
Use sequence parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences
## Configuration
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
## Implementation Details
When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
## Requirements
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
- `pip install axolotl[ring-flash-attn]` (preferred)
- `pip install ring-flash-attn>=0.1.4`
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
flash_attention: true # Required with sequence parallelism
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -1,59 +0,0 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
No personally identifiable information (PII) is collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 15 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
`AXOLOTL_DO_NOT_TRACK=0`.
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -55,7 +55,7 @@ tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"
[project]
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
# license = "Apache-2.0"
[project.scripts]
axolotl = "axolotl.cli.main:main"

View File

@@ -2,3 +2,5 @@ pre-commit
black
mypy
types-requests
quartodoc
jupyter

View File

@@ -1,10 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.2
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3
@@ -12,12 +11,12 @@ liger-kernel==0.5.3
packaging==23.2
peft==0.14.0
peft==0.15.0
transformers==4.49.0
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
deepspeed==0.16.4
trl==0.15.1
optimum==1.16.2
@@ -36,6 +35,7 @@ einops
colorama
numba
numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy
@@ -62,7 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3
# telemetry
posthog>=3.15.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3

View File

@@ -1,6 +1,7 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset

View File

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
@@ -24,5 +25,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
)

View File

@@ -17,11 +17,7 @@ def parse_requirements():
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
@@ -39,7 +35,6 @@ def parse_requirements():
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
@@ -124,11 +119,10 @@ setup(
],
},
extras_require={
"flash-attn": [
"flash-attn==2.7.4.post1",
],
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.1",
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -1,6 +1,7 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union

View File

@@ -1,6 +1,7 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod

View File

@@ -1,6 +1,7 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
@@ -113,7 +114,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
]
)
@@ -270,6 +271,7 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
@@ -288,6 +290,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"

View File

@@ -14,8 +14,6 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -29,8 +27,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -156,7 +152,6 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name)
@send_errors
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -176,7 +171,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
@@ -220,6 +214,4 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
return cfg

View File

@@ -17,7 +17,6 @@ from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
@@ -43,7 +42,6 @@ def get_multi_line_input() -> str:
return instruction
@send_errors
def do_inference(
*,
cfg: DictDefault,
@@ -137,7 +135,6 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio(
*,
cfg: DictDefault,

View File

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging
@@ -24,7 +25,7 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.schemas.config import AxolotlInputConfig
@click.group()

View File

@@ -12,13 +12,11 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -27,7 +27,6 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
LOG = logging.getLogger(__name__)
@@ -121,7 +120,6 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_
@send_errors
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,

View File

@@ -18,14 +18,12 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -22,7 +23,7 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""
Trains a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
@@ -34,23 +35,22 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
Parses `axolotl` config, CLI args, and calls `do_train`.

View File

@@ -5,7 +5,6 @@ import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -24,7 +23,7 @@ configure_logging()
LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | typing._SpecialForm | None):
def strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.

View File

@@ -10,7 +10,6 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@@ -25,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
eval_dataset: Dataset | None = None
total_num_steps: int | None = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
@@ -45,7 +44,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
)
@send_errors
def load_datasets(
*,
cfg: DictDefault,
@@ -105,7 +103,6 @@ def load_datasets(
)
@send_errors
def load_preference_datasets(
*,
cfg: DictDefault,

View File

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json
import sys

View File

@@ -1,6 +1,7 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union

View File

@@ -1,6 +1,7 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
@@ -43,7 +44,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
num_proc = min(32, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -1,6 +1,7 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union

View File

@@ -13,9 +13,7 @@
# limitations under the License.
# pylint: disable=too-many-lines
"""
Builder for the training args and trainer
"""
"""Builder for the training args and trainer"""
import abc
import importlib
@@ -35,9 +33,10 @@ from transformers import (
EarlyStoppingCallback,
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
@@ -61,8 +60,6 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
@@ -87,19 +84,18 @@ from axolotl.utils.collators import (
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
LOG = logging.getLogger(__name__)
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
"""Base class for trainer builder."""
_train_dataset = None
_eval_dataset = None
@@ -112,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instad of trainer.push_to_hub.
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
@@ -178,8 +174,10 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -188,10 +186,6 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -231,8 +225,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):
@@ -336,9 +330,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs[
"include_tokens_per_second"
] = self.cfg.include_tokens_per_second
training_arguments_kwargs["include_tokens_per_second"] = (
self.cfg.include_tokens_per_second
)
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
@@ -355,13 +349,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
training_arguments_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
@@ -377,9 +371,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -403,28 +397,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
training_arguments_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
training_arguments_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
training_arguments_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs[
"dataloader_drop_last"
] = self.cfg.dataloader_drop_last
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
@@ -456,9 +450,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -471,13 +465,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
training_arguments_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
training_arguments_kwargs["torch_compile_mode"] = (
self.cfg.torch_compile_mode
)
# DDP Config
if self.cfg.ddp_timeout:
@@ -486,32 +480,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs["per_device_train_batch_size"] = (
self.cfg.micro_batch_size
)
if self.cfg.eval_batch_size:
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["gradient_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -555,34 +549,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
] = self.cfg.lr_scheduler
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
self.cfg.lr_scheduler
)
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -591,9 +563,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
self.cfg.cosine_constant_lr_ratio
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -606,40 +578,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs[
"sample_packing_bin_size"
] = self.cfg.sample_packing_bin_size
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs[
"sample_packing_group_size"
] = self.cfg.sample_packing_group_size
training_arguments_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
training_arguments_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs[
"lisa_step_interval"
] = self.cfg.lisa_step_interval
training_arguments_kwargs[
"lisa_layers_attribute"
] = self.cfg.lisa_layers_attribute
training_arguments_kwargs["lisa_step_interval"] = (
self.cfg.lisa_step_interval
)
training_arguments_kwargs["lisa_layers_attribute"] = (
self.cfg.lisa_layers_attribute
)
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
@@ -653,59 +625,127 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
training_arguments_kwargs["neftune_noise_alpha"] = (
self.cfg.neftune_noise_alpha
)
trainer_kwargs = {}
if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_arguments_kwargs.get("learning_rate"),
"weight_decay": training_arguments_kwargs.get("weight_decay"),
}
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
# Adam-specific kwargs
adam_kwargs = {}
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
else:
# Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs["optim_target_modules"] = (
self.cfg.optim_target_modules
)
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs["loraplus_lr_embedding"] = (
self.cfg.loraplus_lr_embedding
)
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
] = self.cfg.accelerator_config
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -714,13 +754,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -805,9 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
@@ -835,9 +880,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
@@ -868,6 +911,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator(
*collator_args,
@@ -876,9 +921,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
@@ -932,32 +975,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
training_args_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_args_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
training_args_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.gradient_checkpointing:
training_args_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
@@ -1031,9 +1074,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[
"use_logits_to_keep"
] = self.cfg.dpo_use_logits_to_keep
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
@@ -1068,9 +1111,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]

View File

@@ -0,0 +1,18 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -1,163 +1,47 @@
"""
module for customized trainers
"""
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
# pylint: disable=too-many-lines
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Dict, Literal, Optional
from typing import Any, Literal
import datasets
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
RandomSampler,
Sampler,
SequentialSampler,
)
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
from axolotl.core.trainers.mixins import (
OptimizerMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger("axolotl.core.trainer_builder")
LOG = logging.getLogger(__name__)
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
@@ -174,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -192,316 +82,247 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
def _create_multipack_sampler(
self, base_sampler: Sampler, dataset: Dataset
) -> MultipackBatchSampler:
"""
Helper method to create a `MultipackBatchSampler` for multipacking sequences
for training.
Args:
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
dataset: Dataset to sample from.
Returns:
Multipack (sample packing) batch sampler.
"""
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
drop_last=True,
)
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
else:
# Default to parent class implementation for standard random sampling
return super()._get_train_sampler()
# Apply multipack wrapper if needed
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
)
return base_sampler
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
# Apply multipack wrapper if needed
if use_multipack:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
)
return base_sampler
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
return optimizer_grouped_parameters
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
return params
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
# Return unprepared dataloader if using sequence parallelism
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset:
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return dataloader
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
) -> torch.utils.data.Sampler | None:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
@@ -526,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
@@ -542,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
@@ -716,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@@ -736,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
logs: The values to log.
start_time: The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -756,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
@@ -769,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
def training_step(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
# Proceed with normal training step
loss = super().training_step(model, inputs, num_items_in_batch)
return lm_loss
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]
return loss

View File

@@ -1,6 +1,7 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union
@@ -12,10 +13,10 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.base import (
SchedulerMixin,
_sanitize_kwargs_for_ds_tagging,
_sanitize_kwargs_for_tagging,
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
@@ -73,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)

View File

@@ -9,6 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -31,30 +32,44 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
if not hasattr(cfg, "trl") or not cfg.trl:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
trl.vllm_gpu_memory_utilization
)
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
if trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.ref_model_mixup_alpha:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
if trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
return grpo_args_kwargs
@classmethod
@@ -71,9 +86,9 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
return trainer_kwargs
@classmethod

View File

@@ -1,6 +1,7 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass
from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
"""
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel

View File

@@ -0,0 +1,32 @@
"""Module for mamba trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss

View File

@@ -0,0 +1,8 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -0,0 +1,201 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
class OptimizerMixin(Trainer):
"""Mixin class for shared handling of building custom optimizers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer

View File

@@ -0,0 +1,113 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import OneCycleLR
from transformers.trainer import Trainer
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler

View File

@@ -0,0 +1,131 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import logging
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
including creating appropriate samplers, managing data partitioning,
and updating ring flash attention parameters during training.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn. This is accomplished by using the passed
`input_ids`.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -0,0 +1,43 @@
"""Module for ReLoRA trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler

View File

@@ -1,15 +1,23 @@
"""
module for TRL PPO training
"""
"""Module for TRL PPO trainer"""
import torch
from tqdm import tqdm
from trl import PPOTrainer
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer):
"""
wrapper for ppo trainer to handle customizations
"""
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
def train(
self,
@@ -30,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer):
"batch_size": 16,
}
for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
for _, batch in tqdm(enumerate(self.dataloader)):
query_tensors = batch["input_ids"]
# generate model response
@@ -64,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer):
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -0,0 +1,33 @@
"""Utils for Axolotl trainers"""
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs

View File

@@ -1,6 +1,7 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional
@@ -206,14 +207,19 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
This code is duplicated due to HF TrainingArguments not setting output_dir with a
default value so it can't be used as a mixin.
"""

View File

@@ -8,9 +8,10 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.telemetry.errors import send_errors
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
@@ -26,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
"""Helper function to evaluate a single dataset safely.
"""Helper function to evaluate a single dataset.
Args:
trainer: The trainer instance
dataset: Dataset to evaluate
dataset_type: Type of dataset ('train' or 'eval')
flash_optimum: Whether to use flash optimum
trainer: The trainer instance.
dataset: Dataset to evaluate.
dataset_type: Type of dataset ('train' or 'eval').
flash_optimum: Whether to use flash optimum.
Returns:
Dictionary of metrics or None if dataset is None
Dictionary of metrics or None if dataset is None.
"""
if dataset is None:
return None
@@ -62,20 +63,16 @@ def evaluate_dataset(
return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
Evaluate a model on training and validation datasets.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns:
Tuple containing:
- The model (either PeftModel or PreTrainedModel)
- The tokenizer
- Dictionary of evaluation metrics
Dictionary mapping metric names to their values.
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage

View File

@@ -23,6 +23,8 @@ import importlib
import logging
from typing import OrderedDict
import torch
class BasePlugin:
"""
@@ -469,3 +471,14 @@ class PluginManager:
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -11,19 +11,17 @@
# the License.
"""
module to handle merging the plugins' input arguments with the base configurations.
Module to handle merging the plugins' input arguments with the base configurations.
this was moved here to prevent circular imports
This was moved here to prevent circular imports.
"""
from typing import Any, Dict, List
from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
def merge_input_args():

View File

@@ -4,6 +4,22 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
```yaml

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
)

View File

@@ -1,6 +1,7 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
"""
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_ce_alpha: Optional[float] = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[
bool
] = None # whether to sample top k before softmax during KD
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
)

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,6 +1,7 @@
"""
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
Jamba model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class SpectrumArgs(BaseModel):
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable

View File

@@ -1,4 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl

View File

@@ -1,6 +1,7 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig

Some files were not shown because too many files have changed in this diff Show More