Compare commits

...

56 Commits

Author SHA1 Message Date
Wing Lian
54bbc9bb72 set v0.9.2 version for tag
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-05-13 17:52:33 -04:00
Wing Lian
5aefebe1fe Activation checkpointing with offloading to disk with prefetch (#2663)
* offload activations to disk instead of CPU RAM

* add prefetch

* Disco :dance:

* include offload_disk in e2e test for AC

* document and make sure to cleanup

* fix annotation to match docs

* fix docs build

* address PR feedback
2025-05-13 17:06:31 -04:00
Wing Lian
5a36b6ff2d Atropos support (#2666) [skip ci]
* allow peft+liger+grpo and custom vllm serve for atropos support

* set trainer class for RL
2025-05-13 17:06:05 -04:00
NanoCode012
224da88fa2 fix: disable auto lora kernel if dropout nonzero (#2655) [skip ci]
* fix: disable auto lora kernel if dropout nonzero

* Add comment from PR feedback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-13 17:05:20 -04:00
Wing Lian
493eb8e5c6 update doc and use P2P=LOC for brittle grpo test (#2649)
* update doc and skip brittle grpo test

* fix the path to run the multigpu tests

* increase timeout, use LOC instead of NVL

* typo

* use hf cache from s3 backed cloudfront

* mark grpo as flaky test dues to vllm start
2025-05-13 17:05:11 -04:00
Wing Lian
4780ac7c4d guard on deleting secrets from env (#2653) [skip ci] 2025-05-13 17:03:27 -04:00
Wing Lian
cf69de2eb9 Various fixes for CI, save_only_model for RL, prevent packing multiprocessing deadlocks (#2661)
* lean mistral ft tests, remove e2e torch 2.4.1 test

* make sure to pass save_only_model for RL

* more tests to make ci leaner, add cleanup to modal ci

* fix module for import in e2e tests

* use mp spawn to prevent deadlocks with packing

* make sure cleanup shell script is executable when cloned out
2025-05-13 17:03:08 -04:00
Wing Lian
27e3329273 .post1 version release for multipack fix
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-05-09 21:54:04 -04:00
Dan Saunders
27fec49083 don't sort multipack sampler (#2657)
* don't sort multipack sampler

* increased packing efficiency increases loss

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-09 21:53:29 -04:00
Wing Lian
8cda9e93c1 set version for v0.9.1
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-05-07 16:10:51 -04:00
Wing Lian
17d715c2b3 swap tinymodels that have safetensors for some ci tests (#2641) 2025-05-07 16:10:18 -04:00
xzuyn
f943306263 Add CAME Optimizer (#2385) 2025-05-07 16:10:17 -04:00
NanoCode012
3c8b9b33d6 fix(doc): clarify instruction to delinearize llama4 similar to cli doc (#2644) [skip ci] 2025-05-07 16:10:17 -04:00
NanoCode012
8b0c2a71ad Fix: improve error message on failed dataset load (#2637) [skip ci]
* fix(log): clarify error on dataset loading failed

* fix: add path for easy tracking of broken config

* fix: improve error message based on pr feedback
2025-05-07 16:10:17 -04:00
Wing Lian
493910559a Configurable embeddings upcast (#2621)
* fsdp embeddings should be float32 per comment

* patch peft to not upcast everything

* add tabs back to code check

* fix import

* add configurable option and fix check

* add check for dtypes

* move embeddings test to patch dir

* fix test

* fix comment and logic
2025-05-07 16:10:16 -04:00
Eric Meier
c54534dbfa Fix cut_cross_entropy plugin install (#2642) [skip ci] 2025-05-07 16:10:16 -04:00
Wing Lian
cae5cebb59 xformers attention with packing (#2619)
* xformers attention with packing

* wire up the patch

* fix xformers + packing validation

* fix warning

* reorder the packing check

* fix fp16 / bf16 reset when using fp16 with bf16 auto

* fix seq lens calc to drop hanging sequences

* handle xformers patch for inference too

* fix batch size setter

* fix xformers inference

* add colab callback to fix inference post train

* PR feedback
2025-05-07 16:10:16 -04:00
Wing Lian
fcbd7477d0 Multipack parallel bin packing (#2631)
* improve readability of multipack sampler

* parallel bin packing
fix error with lambda and pickling

make sure things are in float instead of np.float

* annotations and comments update

* support for configurable group and bin size for sample packing

* fix missing map back to original indices
2025-05-07 16:10:15 -04:00
Wing Lian
038db85a40 allow plugins to return their own dataset (#2617) [skip ci]
* allow plugins to return their own dataset

* add post_trainer_create and wire up

* add hook check

* address PR feedback:

* remove annotation causing circular import
2025-05-07 16:10:15 -04:00
NanoCode012
680dcc5a4d feat(doc): add split_thinking docs (#2613) [skip ci]
* feat(doc): add split_thinking docs

* fix: link config.qmd to conversation.qmd for split_thinking example

* update thinking => reasoning_content in messages format

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-07 16:10:15 -04:00
Wing Lian
fed5ca8254 bump liger dep to 0.5.9 (#2640) [skip ci]
* bump liger dep to 0.5.9

* also upgrade vllm to post1, and datasets to 3.5.1
2025-05-07 16:10:15 -04:00
mhenrichsen
7a2d017c88 Update lr_scheduler options in config.qmd to include additional scheduling strategies for improved training flexibility. (#2636) [skip ci] 2025-05-07 16:10:15 -04:00
Wing Lian
8c0303aa5e Print axolotl art if train is called outside of cli: (#2627) [skip ci] 2025-05-07 16:10:14 -04:00
Wing Lian
5d61169f7c fix dpo eval override to call grandparent instead of the broken super (#2628) [skip ci] 2025-05-07 16:10:14 -04:00
Wing Lian
e1586f7919 make sure gc_steps is used for all trainers (#2638) 2025-05-07 16:10:14 -04:00
Wing Lian
e4bf3ffb17 repop cache (#2639)
* repop cache

* pre-cache as a step

* fix the name

* add reason for pytest skipif

* restore pytorch matrix

* remove max-parallel now that we've optimized this a bit
2025-05-07 16:10:14 -04:00
mhenrichsen
30150fe1e1 Adds example for training a TTS model on top of a LLM. (#2614)
* Adds example for training a TTS model on top of a LLM.

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update README.md to clarify GPU requirements for finetuning Orpheus TTS model

* Update finetune.yml to use the new base model canopylabs/orpheus-3b-0.1-pretrained

* Update finetune.yml and README.md for consistency and clarity

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-07 16:10:14 -04:00
Emmanuel Ferdman
7f7d7ade2e Fix logging deprecation warnings (#2623)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-07 16:10:14 -04:00
Wing Lian
776cf70fe4 include multipack support for qwen3 family (#2622) 2025-05-07 16:10:14 -04:00
Wing Lian
8730951aba setup hf transfer too and fix auto bf16 when fp16 enabled (#2620) [skip ci] 2025-05-07 16:10:13 -04:00
Wing Lian
e72c11ad55 qwen3 and qwen3_moe support for liger kernels (#2612)
* qwen3 and qwen3_moe support for liger kernels

* fix moe module path

* fix: qwen3 liger input args and mlp

* fix: qwen3 input args and output class

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-07 16:10:13 -04:00
aitechguy
1a7978b960 remove keys to incoporate changes for the trl update (#2616) 2025-05-07 16:10:13 -04:00
Wing Lian
60b0d14f1d automatically set pad_to_sequence_len when use packing (#2607)
* automatically set pad_to_sequence_len when use packing

* update tests
2025-05-07 16:10:13 -04:00
NanoCode012
a7a40378f5 fix: run preview-docs only when md/qmd changes (#2606)
* fix: run preview-docs only when md/qmd changes

* feat: add quarto yaml based on PR feedback
2025-05-07 16:10:13 -04:00
Wing Lian
b50d35bec9 Logging config for colab (#2611)
* only configure logging on cli to play nicely with colab

* allow reloading the config on the fly from a dict

* make sure to use dict for yaml

* reuse existing function for load

* make cli args optional

* mps fix and respect max_steps
2025-05-07 16:10:13 -04:00
Wing Lian
bc6dfa6899 add missing __init__ for lr monkeypatch fix (#2609) 2025-05-07 16:10:13 -04:00
Dhruv Mullick
9d6e8af622 Add num_completions_to_print for trl and grpo (#2604) 2025-05-07 16:10:12 -04:00
Wing Lian
17b441248c use latest hf-xet and don't install vllm for torch 2.7.0 (#2603)
* use latest hf-xet and don't install vllm for torch 2.7.0

* fix runpod hub tests
2025-05-07 16:10:12 -04:00
Wing Lian
d49a4268b8 additional args for grpo config/trainer (#2598) 2025-05-07 16:10:12 -04:00
Wing Lian
1d6e931115 replace zero_only with simpler if statement (#2592) 2025-05-07 16:10:12 -04:00
Wing Lian
ff106ace44 ensure we pass axolotl extras to the Dockerfile so vllm is included in shipped images (#2599) 2025-05-07 16:10:12 -04:00
Wing Lian
24907533d1 don't automatically enable lora kernels for RL training (#2600) 2025-05-07 16:10:12 -04:00
Wing Lian
0e9d816d2e only import vllm serve cli if its being called (#2597) [skip ci] 2025-05-07 16:10:12 -04:00
Wing Lian
72f142186a Handle other reasoning trace dataset formats (#2591)
* Handle other reasoning trace dataset formats

* rename var to improve readability

* chore: refactor with comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-07 16:10:11 -04:00
Wing Lian
87726322bf upload the deepspeed json to wandb (#2593) [skip ci] 2025-05-07 16:10:11 -04:00
NanoCode012
ae8ae7534c feat: add qwen3 moe block for ds3 (#2596) [skip ci] 2025-05-07 16:10:11 -04:00
Wing Lian
ee00142cb5 patch to convert LR from tensor to float when using DS (#2595) [skip ci] 2025-05-07 16:10:11 -04:00
Aleksandr Dremov
097e7e3b5b Plugins create_lr_scheduler support (#2584)
* lr_scheduler support

* fix

* Update scheduler.py

* Update scheduler.py

* cfg handling

* black

* remove debug

* remove adding the axolotl cfg to the scheduler mixin

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-07 16:10:11 -04:00
Dan Saunders
c714958181 auto-enable lora kernels where possible (#2589)
* auto-enable lora kernels where possible

* test

* revert change to example yaml

* naming

* remove print

* slight logic change
2025-05-07 16:10:11 -04:00
NanoCode012
4402c293dc fix(doc): key used to point to url in multimodal doc (#2575) [skip ci] 2025-05-07 16:10:10 -04:00
Wing Lian
0d71f787a3 bump vllm==0.8.5 for qwen3 support (#2583) [skip ci] 2025-05-07 16:10:10 -04:00
Wing Lian
c337ca0872 support for qwen3 with lora kernels (#2588)
* support for qwen3 with lora kernels

* fix patch

* typo
2025-05-07 16:10:10 -04:00
Dan Saunders
f04f7cf5ad Fix eval + add smoke test (#2586)
* fix evaluate CLI

* add smoke test

* fix naming

* lint
2025-05-07 16:10:10 -04:00
Wing Lian
c64a951bc9 set config on the PluginManager for callback access (#2587) 2025-05-07 16:10:10 -04:00
Wing Lian
fc88cc56cb Post release fixes (#2581)
* fix missing kwarg on child

* make the runpod test shorter

* update docs

* rename runpod test json file

* typing fixes and ordering of doc
2025-05-07 16:10:10 -04:00
Wing Lian
e85cbb8645 remove torch 2.4.1 CI as part of support deprecation (#2582) 2025-05-07 16:10:10 -04:00
103 changed files with 3604 additions and 699 deletions

View File

@@ -22,12 +22,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""

View File

@@ -15,11 +15,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -35,7 +30,7 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -67,6 +62,7 @@ jobs:
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -82,11 +78,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -3,12 +3,13 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -32,13 +33,6 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras: # no vllm support for 2.4.1
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -12,11 +12,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -70,11 +65,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -4,6 +4,12 @@ on:
pull_request:
types: [opened, synchronize, reopened]
# Run the workflow only when one of these files changes
paths:
- '**/*.md' # any Markdown file
- '**/*.qmd' # any Quarto file
- '_quarto.yaml'
permissions:
checks: write
contents: write

View File

@@ -18,15 +18,102 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -106,13 +193,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -27,6 +27,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
TRANSFORMERS_IS_CI: "yes"
jobs:
pre-commit:
name: pre-commit
@@ -41,29 +44,127 @@ jobs:
env:
SKIP: no-commit-to-branch
# preload-cache:
# name: Preload HF cache
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python_version: ["3.11"]
# pytorch_version: ["2.6.0"]
# timeout-minutes: 20
#
# env:
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
#
# steps:
# - name: Check out repository code
# uses: actions/checkout@v4
#
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
#
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p /home/runner/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
#
# - name: Setup Python
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python_version }}
# cache: 'pip' # caching pip dependencies
#
# - name: upgrade pip
# run: |
# pip3 install --upgrade pip
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
#
# - name: Install PyTorch
# run: |
# pip3 install torch==${{ matrix.pytorch_version }}
#
# - name: Install dependencies
# run: |
# pip3 show torch
# pip3 install --no-build-isolation -U -e .
# python scripts/unsloth_install.py | sh
# python scripts/cutcrossentropy_install.py | sh
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
#
# - name: Make sure PyTorch version wasn't clobbered
# run: |
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
#
# - name: Ensure axolotl CLI was installed
# run: |
# axolotl --help
#
# - name: Pre-Download dataset fixture
# run: |
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
#
# - name: Run tests
# run: |
# pytest -v tests/conftest.py
#
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v5
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# files: ./coverage.xml
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
# fail_ci_if_error: false
#
# - name: cleanup pip cache
# run: |
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
#
# - name: Save HF cache
# id: hf-cache
# uses: actions/cache/save@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
# needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -118,38 +219,35 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
# needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -196,15 +294,6 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -258,12 +347,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -300,3 +383,43 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-cleanup:
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

View File

@@ -57,8 +57,10 @@ async def handler(job):
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
if "WANDB_API_KEY" in os.environ:
del os.environ["WANDB_API_KEY"]
if "HF_TOKEN" in os.environ:
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})

86
.runpod/test-input.json Normal file
View File

@@ -0,0 +1,86 @@
{
"input": {
"name": "quick_smoke_test_sft",
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
},
"timeout": 100000
},
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -1,65 +1,70 @@
{
"input": {
"name": "quick_smoke_test_sft",
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
}
},
"timeout": 100000
},
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,

View File

@@ -124,7 +124,8 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:

0
cicd/__init__.py Normal file
View File

View File

@@ -18,7 +18,7 @@ pytest -v --durations=10 \
--cov-append
# Run patched tests excluding lora kernels with coverage append
pytest -v --durations=10 \
pytest --full-trace -vvv --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \

19
cicd/cleanup.py Normal file
View File

@@ -0,0 +1,19 @@
"""Modal app to run axolotl GPU cleanup"""
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cleanup():
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
@app.local_entrypoint()
def main():
cleanup.remote()

6
cicd/cleanup.sh Executable file
View File

@@ -0,0 +1,6 @@
#!/bin/bash
set -e
# cleanup old cache files for datasets processing and intermediate mappings
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;

View File

@@ -1,75 +1,12 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=90 * 60, # 90 min
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,

66
cicd/single_gpu.py Normal file
View File

@@ -0,0 +1,66 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -19,7 +19,7 @@ coverage:
if_no_uploads: error
if_not_found: success
if_ci_failed: error
only_pulls: false
only_pulls: true
flags: null
paths: null
patch:

View File

@@ -32,6 +32,8 @@ tokenizer_legacy:
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
embeddings_skip_upcast:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
@@ -73,11 +75,12 @@ load_in_8bit: true
load_in_4bit:
# Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
@@ -184,6 +187,10 @@ datasets:
# adding a system turn with empty content.
drop_system_message:
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
# See example at `docs/dataset-formats/conversation.qmd`
split_thinking:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
@@ -498,6 +505,7 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
@@ -531,7 +539,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
@@ -543,7 +551,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -605,6 +613,7 @@ lr_div_factor: # Learning rate div factor
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
# - came_pytorch
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:

View File

@@ -196,6 +196,34 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:
- path: ...
type: chat_template
chat_template: qwen3
split_thinking: true
```
For example, a content can look like:
```json
{
"content": "<think>Some thinking outputs</think>Output after thinking."
}
```
After split, it will look like:
```json
{
"reasoning_content": "Some thinking outputs",
"content": "Output after thinking..."
}
```
## sharegpt
::: {.callout-important}

View File

@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},

View File

@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.

341
examples/orpheus/README.md Normal file
View File

@@ -0,0 +1,341 @@
# Finetuning LLMs to output audio
In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.
The `finetune.yml` withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.
## Dataset pre-processing for pre-training
If you are adding another voice in English, please jump ahead to finetuning pre-processing.
For this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.
Using this code, it will download the SNAC model and add the correct tokens and upload the final dataset.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
def create_input_ids(example):
text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Finetune pre-processing
Use this code to add a new voice.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
tok_info = '''*** HERE you can modify the text prompt
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:
f"{example["source"]}: {example["text"]}", as is passed.
'''
print(tok_info)
def create_input_ids(example):
text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Training
After preprocessing is done, fill out the blanks in finetune.yml and simply run `axolotl train finetune.yml`
## Inference
For inference, please refer to the original [orpheus github](https://github.com/canopyai/Orpheus-TTS/tree/main).

View File

@@ -0,0 +1,52 @@
base_model: canopylabs/orpheus-3b-0.1-pretrained
hub_model_id: <your-hub-model-id>
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
datasets:
- path: <your-hf-dataset-id>
type: # leave empty to load pre-tokenized
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 20
evals_per_epoch: 5
saves_per_epoch: 5
weight_decay: 0.05
special_tokens:
pad_token: <custom_token_7>

View File

@@ -6,19 +6,20 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.8
liger-kernel==0.5.9
# END section
packaging==23.2
huggingface_hub==0.31.0
peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.0.0
hf_xet==1.1.0
hqq==0.2.5
optimum==1.16.2

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.4"]
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.4"]
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -142,6 +142,7 @@ extras_require = {
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.9.0"
__version__ = "0.9.2"

View File

@@ -2,4 +2,7 @@
import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -82,6 +82,12 @@ class VllmServeCliArgs:
"hardware support this feature."
},
)
serve_module: Optional[str] = field(
default=None,
metadata={
"help": "Module to serve. If not set, the default module will be used."
},
)
@dataclass

View File

@@ -16,8 +16,15 @@ AXOLOTL_LOGO = """
@@@@ @@@@@@@@@@@@@@@@
"""
HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process():
HAS_PRINTED_LOGO = True
print(AXOLOTL_LOGO)

View File

@@ -8,9 +8,6 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ import logging
import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from urllib.parse import urlparse
@@ -152,7 +153,15 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
@@ -164,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
if isinstance(config, (str, Path)):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
@@ -184,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -213,5 +231,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
return cfg

View File

@@ -1,6 +1,7 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -14,6 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -29,10 +31,14 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -28,9 +28,8 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils import patch_optimized_env
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -56,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -101,7 +102,7 @@ def train(
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
patch_optimized_env()
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -327,6 +328,8 @@ def fetch(directory: str, dest: Optional[str]) -> None:
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
from axolotl.cli.vllm_serve import do_vllm_serve
do_vllm_serve(config, cli_args)

View File

@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if cfg.rl:
plugin_manager = PluginManager.get_instance()
if plugin_manager.load_datasets(cfg, preprocess=True):
pass
elif cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -36,17 +36,20 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
patch_optimized_env()
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -20,11 +20,9 @@ from transformers import (
ProcessorMixin,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
@@ -28,6 +27,9 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)

View File

@@ -11,5 +11,6 @@ MOE_ARCH_BLOCK = {
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
}

View File

@@ -47,7 +47,8 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
@@ -56,6 +57,7 @@ def load_datasets(
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
Returns:
Dataclass with fields for training and evaluation datasets and the computed
@@ -64,7 +66,8 @@ def load_datasets(
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
hasattr(cli_args, "iterable")
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
@@ -76,20 +79,25 @@ def load_datasets(
preprocess_iterable=preprocess_iterable,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
num_examples=num_examples,
text_only=text_only,
)
LOG.info("printing prompters...")

View File

@@ -21,6 +21,7 @@ import importlib.util
import inspect
import logging
import math
import os
import sys
from abc import abstractmethod
from pathlib import Path
@@ -60,6 +61,7 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
@@ -71,6 +73,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -114,6 +117,8 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@@ -165,6 +170,9 @@ class TrainerBuilderBase(abc.ABC):
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -246,9 +254,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -290,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -485,7 +494,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
self.cfg.max_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = (
@@ -699,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
@@ -1034,6 +1057,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
@@ -1163,6 +1188,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
dpo_trainer_kwargs["tokenizer"] = self.tokenizer

View File

@@ -114,6 +114,8 @@ class AxolotlTrainer(
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
@@ -251,7 +247,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
)
# Base evaluation
initial_output = super().evaluation_loop(
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,

View File

@@ -63,6 +63,7 @@ class GRPOStrategy:
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -70,6 +71,13 @@ class GRPOStrategy:
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.loss_type is not None:
grpo_args_kwargs["loss_type"] = trl.loss_type
if trl.mask_truncated_completions is not None:
grpo_args_kwargs["mask_truncated_completions"] = (
trl.mask_truncated_completions
)
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
@@ -85,6 +93,11 @@ class GRPOStrategy:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
if trl.epsilon_high is not None:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
return grpo_args_kwargs

View File

@@ -3,9 +3,10 @@
import logging
import torch
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -25,9 +26,9 @@ class SchedulerMixin(Trainer):
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
) -> LRScheduler:
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
@@ -47,7 +48,16 @@ class SchedulerMixin(Trainer):
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
trainer=self,
optimizer=optimizer,
num_training_steps=num_training_steps
)
if lr_scheduler is not None:
LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
self.lr_scheduler = lr_scheduler
elif self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
@@ -110,4 +120,4 @@ class SchedulerMixin(Trainer):
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
return self.lr_scheduler # type: ignore

View File

@@ -1,6 +1,7 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -19,9 +20,11 @@ class ReLoRATrainer(AxolotlTrainer):
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
):
) -> LRScheduler:
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
if self.args.relora_steps:
warmup_steps = (
@@ -30,7 +33,7 @@ class ReLoRATrainer(AxolotlTrainer):
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
self.lr_scheduler = ReLoRAScheduler( # type: ignore
optimizer,
lr_scheduler,
self.args.relora_steps,
@@ -38,6 +41,6 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
self.lr_scheduler = lr_scheduler # type: ignore
return self.lr_scheduler
return self.lr_scheduler # type: ignore

View File

@@ -11,20 +11,19 @@ from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger("axolotl.evaluate")
LOG = get_logger(__name__)
def evaluate_dataset(
@@ -75,37 +74,22 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
Returns:
Dictionary mapping metric names to their values.
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
# Load tokenizer, processor and model
LOG.debug("loading model for evaluation...")
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(cfg, tokenizer, processor=processor)
# Set up trainer
trainer = setup_trainer(
cfg,
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=(model, None, None), # No need for model_ref or peft_config
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,

View File

@@ -24,6 +24,9 @@ import logging
from typing import OrderedDict
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.utils.dict import DictDefault
class BasePlugin:
@@ -35,13 +38,15 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
"""
@@ -62,20 +67,32 @@ class BasePlugin:
None
"""
def get_input_args(self):
def get_input_args(self) -> str | None:
"""
Returns a pydantic model for the plugin's input arguments.
"""
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
preprocess: Whether this is the preprocess step of the datasets.
Returns:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
@@ -90,86 +107,99 @@ class BasePlugin:
"""
Performs actions after the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
object: The created optimizer.
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
object: The created learning rate scheduler.
object (LRScheduler): The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
@@ -180,12 +210,12 @@ class BasePlugin:
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
List[callable]: A list of callback functions to be added
"""
return []
@@ -193,23 +223,23 @@ class BasePlugin:
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Args:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
@@ -270,6 +300,7 @@ class PluginManager:
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
_cfg = None
def __new__(cls):
"""
@@ -277,7 +308,9 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict()
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
collections.OrderedDict()
)
return cls._instance
@staticmethod
@@ -290,6 +323,14 @@ class PluginManager:
PluginManager()
return PluginManager._instance # type: ignore
@property
def cfg(self):
return self._cfg
@cfg.setter
def cfg(self, cfg):
self._cfg = cfg
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
@@ -325,6 +366,27 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def load_datasets(self, cfg, preprocess: bool = False):
"""
Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess : Whether this is preprocess step of the datasets.
Returns:
dataset_meta: The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
dataset_meta = plugin.load_datasets(cfg, preprocess)
if dataset_meta is not None:
if return_ds_meta is None:
return_ds_meta = dataset_meta
else:
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
@@ -409,29 +471,43 @@ class PluginManager:
return trainer_cls
return None
def create_optimizer(self, cfg, trainer):
def post_trainer_create(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Calls the post_trainer_create method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters:
trainer (object): The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
optimizer = plugin.create_optimizer(self.cfg, trainer)
if optimizer is not None:
return optimizer
return None
def create_lr_scheduler(self, cfg, trainer, optimizer):
def create_lr_scheduler(
self, trainer, optimizer, num_training_steps
) -> LRScheduler | None:
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
@@ -439,7 +515,12 @@ class PluginManager:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
scheduler: LRScheduler | None = plugin.create_lr_scheduler(
self.cfg,
trainer=trainer,
optimizer=optimizer,
num_training_steps=num_training_steps,
)
if scheduler is not None:
return scheduler
return None

View File

@@ -25,7 +25,7 @@ import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import is_main_process
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -76,7 +76,7 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
with zero_only():
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)

View File

@@ -37,6 +37,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=None,
train_on_eot=None,
eot_tokens=None,
split_thinking: bool | None = False,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
@@ -54,6 +55,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
train_on_eos=train_on_eos,
train_on_eot=train_on_eot,
eot_tokens=eot_tokens,
split_thinking=split_thinking,
)
@property

View File

@@ -23,8 +23,8 @@ import logging
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
@@ -85,7 +85,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
if is_main_process(use_environ=True):
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
@@ -151,6 +151,30 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
else:
logging.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."

View File

@@ -0,0 +1,160 @@
"""
Liger FLCE for Qwen3. Based on transformers v4.51.3.
"""
import sys
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
if rms_norm:
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
if glu_activation:
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
if layer_norm:
modeling_qwen3.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,191 @@
"""
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
"""
import sys
from copy import deepcopy
from typing import List, Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
if rms_norm:
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,19 @@
"""
attention module for attention monkeypatches
"""
from transformers.integrations.flash_attention import flash_attention_forward
def patch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .xformers import xformers_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
def unpatch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()

View File

@@ -12,10 +12,8 @@ import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)

View File

@@ -0,0 +1,160 @@
"""
xformers attention implementation for packing
"""
from typing import Optional
import torch
import xformers
import xformers.ops.fmha
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
xformers_attention = xformers.ops.fmha.memory_efficient_attention
def xformers_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
dropout: float = 0.0, # pylint: disable=unused-argument
scaling: Optional[float] = None, # pylint: disable=unused-argument
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None, # pylint: disable=unused-argument
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
if cu_seq_lens_q is None or cu_seq_lens_k is None:
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
cu_seq_lens_q = cu_seq_lens_q.squeeze()
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
attn_bias = (
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths.tolist(),
)
)
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
elif attention_mask is not None:
query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
kv_seqlen=seq_lengths,
)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
else:
# Handle Group Query Attention (GQA) using view/expand approach from reference
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
value = value.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
)
# If we need a sliding window attention
if has_sliding_window:
query = query.view(
1,
batch_size * query_length,
num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
)
return attn_output, None

View File

@@ -23,22 +23,42 @@ from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
QKV_PATCHES = [
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
"\n"
),
),
(
"""
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
),
]
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert any(
qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
), "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(
qkv_orig,
qkv_patched,
)
break
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",

View File

@@ -18,6 +18,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"qwen2_moe",
"qwen3",
"qwen3_moe",
"falcon",
"phi",
"phi3",

View File

View File

@@ -0,0 +1,78 @@
"""
Patch prepare_model_for_kbit_training to not upcast everything
"""
import inspect
import logging
import peft
import axolotl
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_PREPARE_CODE = """
for param in model.parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32)
"""
PATCHED_PREPARE_CODE = """
for name, param in model.named_parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
param.data = param.data.to(torch.float32)
"""
def get_peft_prep_code() -> str:
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
return prepare
def check_peft_prep_code_is_patchable() -> bool:
prep_code = get_peft_prep_code()
prep_code, _ = detab_code(prep_code)
return ORIGINAL_PREPARE_CODE in prep_code
def patch_peft_prep_code():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
prep_code = get_peft_prep_code()
except OSError:
return
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
prep_code
)
prep_code, _ = detab_code(prep_code)
if ORIGINAL_PREPARE_CODE not in prep_code:
return
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
prep_code = prep_code.replace(
"def prepare_model_for_kbit_training(",
"def fixed_prepare_model_for_kbit_training(",
1,
)
items_to_import = []
for item in dir(peft.utils.other):
if item in prep_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
globals(),
)
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

View File

@@ -0,0 +1,42 @@
"""
monkeypatch for Trainer _get_learning_rate method
"""
import logging
import torch
LOG = logging.getLogger(__name__)
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
def _get_learning_rate(self):
if self.is_deepspeed_enabled:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
LOG.warning(
"tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
)
last_lr = 0
else:
raise
else:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
last_lr = self.optimizer.param_groups[0]["lr"]
else:
last_lr = self.lr_scheduler.get_last_lr()[0]
if torch.is_tensor(last_lr):
last_lr = last_lr.item()
return last_lr
def patch_trainer_get_lr():
from transformers.trainer import Trainer
Trainer._get_learning_rate = _get_learning_rate # pylint: disable=protected-access

View File

@@ -4,7 +4,7 @@ HF Chat Templates prompt strategy
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
@@ -29,12 +29,12 @@ class ChatTemplatePrompter(Prompter):
chat_template: str,
processor=None,
max_length=2048,
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
message_property_mappings: Dict[str, str] | None = None,
message_field_training: str | None = None,
message_field_training_detail: str | None = None,
field_messages: str = "messages",
field_system: str = "system",
roles: Optional[Dict[str, List[str]]] = None,
roles: Dict[str, List[str]] | None = None,
drop_system_message: bool = False,
):
# check if message_property_mappings is None or empty dict
@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
}
if roles:
@@ -65,7 +66,7 @@ class ChatTemplatePrompter(Prompter):
self.field_messages = field_messages
self.field_system = field_system
self.tokenizer = tokenizer
self.processor: Optional[ProcessorMixin] = processor
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
@@ -224,11 +225,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
tokenizer,
train_on_inputs: bool,
sequence_len: int,
roles_to_train: Optional[List[str]] = None,
train_on_eos: Optional[str] = None,
train_on_eot: Optional[str] = None,
eot_tokens: Optional[List[str]] = None,
split_thinking: Optional[bool] = False,
roles_to_train: list[str] | None = None,
train_on_eos: str | None = None,
train_on_eot: str | None = None,
eot_tokens: list[str] | None = None,
split_thinking: bool | None = False,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
@@ -661,16 +662,46 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# if the role is assistant that we want to use reasoning_content
if self.split_thinking and transformed_message["role"] == "assistant":
content = transformed_message["content"]
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
for pair in pairs:
if pair[0] in content and pair[1] in content:
start_idx = content.find(pair[0])
end_idx = content.find(pair[1])
thinking_content = content[start_idx + len(pair[0]) : end_idx]
thinking_pairs = [
("<think>", "</think>"),
("<reasoning>", "</reasoning>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
]
content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
for tpair in thinking_pairs:
# check if the thinking pair is in the content
if tpair[0] in content and tpair[1] in content:
# find the start and end index of the thinking pair
t_start_idx = content.find(tpair[0])
t_end_idx = content.find(tpair[1])
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message["reasoning_content"] = thinking_content.strip()
transformed_message["content"] = content[
end_idx + len(pair[1]) :
].lstrip()
# take remainder of the content
# strip whitespace from beginning of the remainder (thinking tokens)
remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
# check if the content pair is in the remainder
cpair_found = False
for cpair in content_pairs:
if cpair[0] in remainder and cpair[1] in remainder:
# find the start and end index of the content pair
c_start_idx = remainder.find(cpair[0])
c_end_idx = remainder.find(cpair[1])
# get the content content
content_content = remainder[
c_start_idx + len(cpair[0]) : c_end_idx
]
transformed_message["content"] = content_content.strip()
cpair_found = True
break
# else, the content is the remainder
if not cpair_found:
transformed_message["content"] = remainder
break
# Determine which keys in the original message were not mapped
@@ -714,7 +745,7 @@ class StrategyLoader:
self,
tokenizer,
cfg,
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
ds_cfg: Union[Dict[str, Any], DatasetConfig] | None = None,
processor=None,
):
if ds_cfg is None:

View File

@@ -2,6 +2,7 @@
import importlib
import inspect
import logging
import os
import signal
import sys
@@ -12,7 +13,6 @@ from typing import Any, Dict
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
@@ -21,6 +21,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
@@ -30,7 +31,6 @@ from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -42,8 +42,7 @@ try:
except ImportError:
BetterTransformer = None
configure_logging()
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def setup_model_and_tokenizer(
@@ -64,7 +63,6 @@ def setup_model_and_tokenizer(
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
@@ -503,6 +501,8 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
print_axolotl_text_art()
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,
@@ -512,6 +512,9 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
@@ -534,7 +537,6 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)
def patch_optimized_env():
"""
Patch environment variables to improve VRAM usage and increase download speed
"""
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import gc
import json
import logging
import os
import traceback
@@ -808,11 +809,44 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
if args.deepspeed:
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
with NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
prefix="deepspeed_config_",
) as temp_file:
skip_upload = False
if isinstance(args.deepspeed, dict):
json.dump(args.deepspeed, temp_file, indent=4)
elif isinstance(args.deepspeed, str) and os.path.exists(
args.deepspeed
):
copyfile(args.deepspeed, temp_file.name)
else:
skip_upload = True
if not skip_upload:
artifact = wandb.Artifact(
f"deepspeed-config-{wandb.run.id}",
type="deepspeed-config",
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The DeepSpeed config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving DeepSpeed config to WandB: {err}")
return control
@@ -834,3 +868,28 @@ class GCCallback(TrainerCallback):
):
torch.cuda.empty_cache()
gc.collect()
def colab_inference_post_train_callback(trainer: Trainer):
class ColabCallback(TrainerCallback):
"""Callback to prep model for inference on Google Colab"""
def __init__(self, cfg):
self.gpu_name = torch.cuda.get_device_name(0)
self.cfg = cfg
def on_train_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
"""
handle T4 gpu, we need to convert attention to eager for inference
"""
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
trainer.model.gradient_checkpointing_disable()
trainer.model.config.use_cache = True
trainer.model.eval()
return ColabCallback

View File

@@ -59,7 +59,7 @@ def choose_device(cfg):
def resolve_dtype(cfg):
if (
cfg.bf16 == "auto" and not cfg.use_ray
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
): # if we use ray we want to defer this check to the worker node
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
@@ -67,9 +67,12 @@ def resolve_dtype(cfg):
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.fp16 is None:
if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True
if cfg.fp16 and cfg.bf16 == "auto":
cfg.bf16 = False
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False

View File

@@ -281,6 +281,10 @@ def load_dataset_w_config(
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")
raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
)
return ds

View File

@@ -69,17 +69,27 @@ def barrier():
dist.barrier()
def is_main_process():
def is_main_process(use_environ=False):
"""
Check if the current process is the main process. If not in distributed mode,
always return `True`.
Args:
- use_environ (bool, optional): Use environment variable to determine main process.
Returns:
- bool: `True` if the current process is the main process, `False` otherwise.
"""
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
if not is_distributed():
return True
return dist.get_rank() == 0
def is_local_main_process():
def is_local_main_process(use_environ=False):
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
return PartialState().is_local_main_process
@@ -99,17 +109,6 @@ def cleanup_distributed():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager
def zero_first(is_main):
"""

View File

@@ -1,16 +1,59 @@
"""custom checkpointing utils"""
import importlib
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
from packaging import version
from axolotl.utils.gradient_checkpointing.offload_cpu import (
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.utils.gradient_checkpointing.offload_disk import (
Disco,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
if transformers_version > version.parse("4.51.3"):
from transformers.modeling_layers import GradientCheckpointingLayer
def uses_gc_layers(decoder_layer):
return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
else:
def uses_gc_layers(_):
return False
def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
if uses_gc_layers(decoder_layer):
return CPU_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return CPU_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
def hf_grad_checkpoint_disk_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Disco.apply(
decoder_layer,
*args,
)
return Disco.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)

View File

@@ -1,4 +1,4 @@
"""Unsloth checkpointing"""
"""CPU offloaded checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""

View File

@@ -0,0 +1,531 @@
"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict
import torch
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger
logger = logging.getLogger(__name__)
class DiskOffloadManager:
"""
Manages offloaded tensors and handles prefetching in a separate thread.
Includes synchronization to prevent race conditions.
"""
def __init__(
self,
prefetch_size: int = 3,
prefetch_to_gpu: bool = True,
save_workers: int = 4,
):
"""
Args:
prefetch_size: Maximum number of tensors to prefetch in the background.
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
save_workers: Maximum number of concurrent save operations.
"""
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
# Track tensor paths and their status
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
self.file_locks: Dict[str, threading.Lock] = (
{}
) # Maps file_path -> threading.Lock()
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
self.file_status: Dict[str, str] = {}
self.max_prefetch = prefetch_size
self.prefetch_to_gpu = prefetch_to_gpu
# Thread synchronization
self.manager_lock = threading.RLock() # Used for thread-safe operations
# Prefetch queue and cache
self.prefetch_queue: queue.Queue = queue.Queue()
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
# Save queue and thread pool
self.save_queue: queue.Queue = queue.Queue()
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
self.save_futures: Dict[str, Future] = {}
self.save_semaphore = threading.Semaphore(
save_workers * 2
) # Limit concurrent save operations
# Start prefetch worker thread
self.stop_event = threading.Event()
# start multiple threads for prefetching
self.prefetch_worker_count = 2
self.prefetch_workers = []
for _ in range(self.prefetch_worker_count):
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
worker.start()
self.prefetch_workers.append(worker)
# Start save worker thread
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
self.save_worker.start()
self.idx = 0
atexit.register(self.cleanup)
def _save_worker(self):
"""Background thread that processes the save queue"""
while not self.stop_event.is_set():
try:
save_item = self.save_queue.get(timeout=0.5)
if save_item is None:
continue
tensor, file_path = save_item
# Submit the save task to the thread pool
future = self.save_pool.submit(
self._save_tensor_to_disk, tensor, file_path
)
with self.manager_lock:
self.save_futures[file_path] = future
self.save_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
"""Actually save the tensor to disk"""
try:
# Save tensor to disk
cpu_tensor = tensor.detach().cpu()
torch.save(cpu_tensor, file_path)
del cpu_tensor
with self.manager_lock:
# Mark file as ready
self.file_status[file_path] = "ready"
# Release semaphore
self.save_semaphore.release()
return True
except FileNotFoundError as e:
logger.error(f"Error saving tensor to {file_path}: {e}")
with self.manager_lock:
self.file_status[file_path] = "error"
# Release semaphore
self.save_semaphore.release()
return False
def _prefetch_worker(self):
"""Background thread that loads tensors from disk ahead of time"""
while not self.stop_event.is_set():
try:
file_path = self.prefetch_queue.get(timeout=0.5)
if file_path is None:
continue
# Check if file is available and not already in cache
with self.manager_lock:
if (
file_path not in self.file_status
or self.file_status[file_path] == "deleted"
):
self.prefetch_queue.task_done()
if file_path in self.prefetch_cache:
self.prefetch_queue.task_done()
continue
# If file is still being saved, wait for it
if (
self.file_status[file_path] == "saving"
and file_path in self.save_futures
):
# Re-queue this prefetch request with a little delay
self.prefetch_queue.task_done()
time.sleep(0.1)
self.prefetch_queue.put(file_path)
continue
# Mark file as being prefetched
self.file_status[file_path] = "prefetching"
# Load tensor from disk and store in cache
try:
if os.path.exists(file_path):
if self.prefetch_to_gpu:
tensor = torch.load(
file_path,
map_location=torch.device("cuda"),
weights_only=True,
)
else:
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.prefetch_cache[file_path] = tensor
self.file_status[file_path] = "ready"
else:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(
f"Prefetch error: File not found {file_path}"
)
self.file_status[file_path] = "missing"
except FileNotFoundError as e:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(f"Prefetch error for {file_path}: {e}")
self.file_status[file_path] = "error"
self.prefetch_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def save_tensor(self, tensor: torch.Tensor):
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
# Generate unique file path
self.idx += 1
file_path: str = os.path.join(
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
)
with self.manager_lock:
# Mark file as being saved
self.file_locks[file_path] = threading.Lock()
self.file_status[file_path] = "saving"
# Add to history
self.tensor_paths.append(file_path)
# Acquire semaphore to limit concurrent save operations
self.save_semaphore.acquire() # pylint: disable=consider-using-with
# Queue tensor for saving in background
self.save_queue.put((tensor.detach(), file_path))
return file_path
def wait_for_save(self, file_path, timeout=None) -> None:
"""Wait for a tensor to be saved to disk"""
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
with self.manager_lock:
if self.file_status.get(file_path) == "ready":
return
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
return
if file_path in self.save_futures:
future = self.save_futures[file_path]
if future.done():
return
# Small sleep to prevent CPU spinning
time.sleep(0.01)
# Timeout
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
return
def load_tensor(self, file_path, target_device="cuda"):
"""Load tensor from disk or prefetch cache with proper synchronization"""
# Wait for tensor to be saved if it's still in progress
self.wait_for_save(file_path)
tensor = None
# Try to get from cache first
with self.manager_lock:
# Check if tensor is already in cache
if file_path in self.prefetch_cache:
tensor = self.prefetch_cache[file_path]
del self.prefetch_cache[file_path]
self.file_status[file_path] = "loaded"
if tensor is not None:
# Ensure tensor is on correct device
if target_device != "cpu" and tensor.device.type == "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
# If not in cache, load directly from disk
try:
if not os.path.exists(file_path):
logger.error(f"File not found for loading: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.file_status[file_path] = "loaded"
if target_device != "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
except Exception as e:
logger.error(f"Error loading tensor from {file_path}: {e}")
raise
def _safe_delete_file(self, file_path):
"""Safely delete a file with proper synchronization"""
with self.manager_lock:
# Make sure any save operation is completed
if file_path in self.save_futures:
future = self.save_futures[file_path]
try:
if not future.done():
future.cancel()
del self.save_futures[file_path]
except FileNotFoundError as e:
logger.warning(
f"Error canceling save operation for {file_path}: {e}"
)
# Only delete if file exists and is not being prefetched
status = self.file_status.get(file_path)
if status in ["ready", "loaded", "error", "missing"]:
try:
if os.path.exists(file_path):
os.remove(file_path)
self.file_status[file_path] = "deleted"
return True
except FileNotFoundError as e:
logger.warning(f"Error deleting file {file_path}: {e}")
return False
def trigger_prefetch(self, n=None):
"""Trigger prefetching of the next N tensors with proper synchronization"""
if n is None:
n = self.max_prefetch
prefetch_paths = []
with self.manager_lock:
# Find files that are ready to be prefetched (not already in cache or being prefetched)
for path in reversed(self.tensor_paths):
if (
path not in self.prefetch_cache
and self.file_status.get(path) == "ready"
):
prefetch_paths.append(path)
if len(prefetch_paths) >= n:
break
# Queue files for prefetching
for path in prefetch_paths:
self.prefetch_queue.put(path)
def cleanup_tensor(self, file_path: str):
"""Clean up a specific tensor file after it's been used"""
with self.manager_lock:
if file_path in self.tensor_paths:
self.tensor_paths.remove(file_path)
# Remove from prefetch cache if present
if file_path in self.prefetch_cache:
del self.prefetch_cache[file_path]
# Remove from save futures if present
if file_path in self.save_futures:
future = self.save_futures[file_path]
if not future.done():
future.cancel()
del self.save_futures[file_path]
# Try to delete the file
self._safe_delete_file(file_path)
def cleanup(self):
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
self.stop_event.set()
# Cancel all pending save operations
with self.manager_lock:
for _, future in self.save_futures.items():
if not future.done():
future.cancel()
self.save_futures.clear()
# Drain the save queue
while not self.save_queue.empty():
try:
self.save_queue.get_nowait()
self.save_queue.task_done()
except queue.Empty:
break
# Shutdown the save pool
self.save_pool.shutdown(wait=False)
# Join the save worker thread
if self.save_worker.is_alive():
self.save_worker.join(timeout=2.0)
# Join the prefetch worker threads
for thread in self.prefetch_workers:
if thread.is_alive():
thread.join(timeout=2.0)
# Clear cache and remove all temporary files
with self.manager_lock:
self.prefetch_cache.clear()
paths_to_delete = list(self.tensor_paths)
self.tensor_paths.clear()
# Delete all temporary files
for path in paths_to_delete:
self._safe_delete_file(path)
# Remove temp directory
try:
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir, ignore_errors=True)
except FileNotFoundError as e:
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
class Disco(torch.autograd.Function):
"""
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
Advanced disk-based gradient checkpointer with prefetching.
"""
# Shared manager instance across all checkpointing operations
_manager = None
@staticmethod
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
"""Get or create the offload manager"""
if Disco._manager is None:
Disco._manager = DiskOffloadManager(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
return Disco._manager
@staticmethod
@torch_cuda_amp_custom_fwd
def forward(
ctx,
forward_function,
hidden_states,
*args,
prefetch_size=1,
prefetch_to_gpu=True,
save_workers=4,
):
"""Forward pass that offloads activations to disk asynchronously"""
# Get or create the manager
manager = Disco.get_instance(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
# Save tensor to disk asynchronously
file_path = manager.save_tensor(hidden_states)
# Run forward pass immediately without waiting for save to complete
with torch.no_grad():
output = forward_function(hidden_states, *args)
# Store what we need for backward
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
ctx.file_path = file_path
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch_cuda_amp_custom_bwd
def backward(ctx, *grad_outputs):
"""Backward pass that loads activations from disk with prefetching"""
# Get the manager
manager = Disco._manager
# Trigger prefetching for future tensors
# This happens at the start of backward, so should have time to complete
manager.trigger_prefetch()
# Load hidden states from disk or prefetch cache
file_path = ctx.file_path
try:
# Ensure the file is saved before we try to load it
manager.wait_for_save(file_path)
hidden_states = manager.load_tensor(file_path)
hidden_states.requires_grad = True
# Compute gradients
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Handle tuple outputs properly
if isinstance(output, tuple):
if len(grad_outputs) == len(output):
torch.autograd.backward(output, grad_outputs)
else:
torch.autograd.backward(output, grad_outputs[0])
else:
torch.autograd.backward(output, grad_outputs[0])
# Clean up the file after we're done with it
manager.cleanup_tensor(file_path)
return (
(
None, # forward_function
hidden_states.grad, # hidden_states grad
)
+ (None,) * len(ctx.args) # for each arg
+ (
None, # prefetch_size
None, # prefetch_to_gpu
None, # save_workers
)
)
except Exception as e:
logger.error(f"Error in backward pass: {e}")
# Clean up the file even on error
manager.cleanup_tensor(file_path)
raise

View File

@@ -68,9 +68,12 @@ from axolotl.utils.distributed import (
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
hf_grad_checkpoint_offload_wrapper,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -437,7 +440,7 @@ def load_tokenizer(cfg):
{"additional_special_tokens": additional_special_tokens}
)
with zero_only():
if is_main_process(use_environ=True):
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -540,11 +543,21 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
@@ -593,6 +606,10 @@ class ModelLoader:
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
if self.cfg.flash_attention:
self.patch_attention()
@@ -1164,7 +1181,7 @@ class ModelLoader:
],
)
def prepare_model(self, qlora_fsdp) -> None:
def prepare_model(self, qlora_fsdp: bool) -> None:
skip_prepare_model_for_kbit_training = False
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -1294,7 +1311,10 @@ class ModelLoader:
# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
if not self.cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
# we don't run this during FSDP because this will leave mixed
# float and bfloat16 dtypes in the model which FSDP doesn't like
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
embedding_modules = []
self.convert_embedding_modules_dtype(
embedding_modules,
dist_dtype=torch.float32,

View File

@@ -1,10 +1,13 @@
# pylint: skip-file
"""
Multipack Batch Sampler
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import logging
import math
from typing import Any, Iterable, List, Union
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
from typing import Iterable, Union
import numba
import numpy as np
@@ -13,26 +16,39 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
"""
First-fit-decreasing bin packing algorithm check
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
not_found = True
for idx in range(n):
for idx in range(num_bins):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
# If no bin could fit this sequence, packing failed
if not_found:
return False
@@ -40,86 +56,155 @@ def ffd_check(a: np.ndarray, c: int, n: int):
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
def pack_group(
sequence_lengths: np.ndarray,
group_offset: int,
bin_capacity: int,
max_bins: int,
bin_size: int,
safe_mode: bool = True,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
indices = np.argsort(a)[::-1]
a = a[indices]
Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
bin_size: Maximum number of sequences per bin
safe_mode: If True, use a more conservative packing approach
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sequence_lengths):
global_idx = seq_id + group_offset
# Try to place sequence in existing bins
add_new_bin = True
for bin_idx, _ in enumerate(bins_remaining_space):
if (
bins_remaining_space[bin_idx] >= size
and len(bins_assigned_sequences[bin_idx]) < bin_size
):
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
# Create a new bin if needed and if we haven't reached the limit
if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([global_idx])
return bins_result
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
# Define a standalone function for multiprocessing
def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
return pack_group(
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
)
def pack_parallel(
sequence_lengths: np.ndarray,
bin_capacity: int,
group_size: int,
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
mp_start_method: str | None = "spawn",
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
"""
Pack sequences into bins using parallel processing
s = 0
start_index = 0
result = []
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin as total number of tokens
group_size: Number of sequences to process in each group
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
'spawn' is often safer with Numba/PyTorch.
Set to None to use system default.
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
# Create tasks for parallel processing
tasks = []
for i in range(0, num_items, group_size):
group_lengths = sequence_lengths[i : i + group_size]
max_bins = len(group_lengths) # Allow as many bins as items in the group
tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# Process groups in parallel
all_bins = []
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
mp_ctx = None
if mp_start_method:
try:
mp_ctx = get_context(mp_start_method)
except ValueError:
LOG.warning(
f"Failed to get multiprocessing context '{mp_start_method}'. "
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
)
mp_ctx = (
None # Fallback to default context if specified one is not available
)
start_index += left
s = lengths_cumsum[start_index - 1]
if num_processes == 1:
LOG.debug("Using single process for pack_parallel, running sequentially.")
for task_args in tasks:
group_bins = _process_group(task_args)
all_bins.extend(group_bins)
else:
# Use ProcessPoolExecutor only if num_processes > 1
# Pass mp_context if available
with ProcessPoolExecutor(
max_workers=num_processes, mp_context=mp_ctx
) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
return all_bins
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Args:
sequence_lengths: The lengths of all examples
rank: The current rank (for distributed training)
bin_capacity: The capacity of each bin (maximum sequence length)
num_ranks: Number of ranks (processes/GPUs)
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
rank_batches: List of batches for the current rank
total_tokens_used: Number of actual example tokens
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
@@ -127,9 +212,9 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
remaining_capacity = bin_capacity
for idx, size in enumerate(lengths):
for idx, size in enumerate(sequence_lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
@@ -140,7 +225,7 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
remaining_capacity = bin_capacity - size
total_used += size
# Add the last bin if not empty
@@ -148,132 +233,227 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
for bin_idx in range(rank, len(all_bins), num_ranks):
result.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
return result, total_used, len(all_bins) * bin_capacity
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
"""
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.lengths = np.array(lengths, dtype=np.int32)
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
self.group_size = group_size
self.bin_size = bin_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
# Efficiency statistics tracking
self.total_tokens_used = 0
self.total_token_slots = 0
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
LOG.warning(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False):
indices = [idx for idx in self.sampler]
"""
Generate packed batches for training
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
Args:
set_stats: Whether to update efficiency statistics
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
Returns:
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
if self._batches is not None:
return self._batches
batches = [
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
# Get indices from the sampler
indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
# Get lengths of the selected sequences
lengths = self.lengths[indices]
# Pack sequences into bins using either sequential or parallel packing
if self.sequential:
bins, total_used, total_slots = allocate_sequentially(
lengths,
rank=0,
bin_capacity=self.batch_max_len,
num_ranks=1,
)
# Map bin indices back to original indices
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
else:
# Use parallel packing
all_bins = pack_parallel(
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
)
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins)
batches = [
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested
if set_stats:
self.total_tokens_used += total_used
self.total_token_slots += total_slots
self._batches = batches
return batches
def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks
# Truncate batches to ensure all ranks have the same number of batches
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
def gather_efficiency(self):
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return math.floor(0.997 * max(estimates))
"""
Gather and synchronize packing efficiency estimates across all distributed ranks
Returns a conservative efficiency estimate based on the measurements
"""
def calc_sample_packing_eff_est(estimates: list[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate
max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
# Quantize to 0.5% intervals for stability
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
"""
Gather and synchronize batch counts across all distributed ranks
Returns the minimum number of batches available on any rank
"""
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates))
# Find minimum batch count across ranks to ensure balance
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks

View File

@@ -82,6 +82,7 @@ class AxolotlInputConfig(
mean_resizing_embeddings: bool | None = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None
embeddings_skip_upcast: bool | None = None
rl: RLType | None = None
trl: TRLConfig | None = Field(
@@ -177,7 +178,7 @@ class AxolotlInputConfig(
# torch_dtype: torch.dtype | None
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
@@ -435,16 +436,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before")
@classmethod
# pylint: disable=duplicate-code
@@ -471,9 +462,10 @@ class AxolotlInputConfig(
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
and not data.get("xformers_attention")
):
LOG.warning(
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
)
return data
@@ -512,10 +504,17 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@@ -1150,6 +1149,30 @@ class AxolotlInputConfig(
return data
# @model_validator(mode="before")
# @classmethod
# def check_grpo_peft_liger(cls, data):
# if (
# data.get("rl") == "grpo"
# and data.get("trl", {})
# and data.get("trl").get("use_liger_loss")
# and data.get("adapter")
# ):
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
# return data
#
@model_validator(mode="before")
@classmethod
def check_grpo_liger_sequence_parallel(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("sequence_parallel_degree", 1) > 1
):
raise ValueError("GRPO + SP + Liger not currently supported")
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
@@ -1315,6 +1338,61 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):
# Only proceed if using LoRA or QLoRA adapter
if data.get("rl"):
# RL trainers not tested so don't enable kernels by default
return data
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
or data.get("adapter") == "lora"
and data.get("load_in_8bit")
):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
is_fsdp = data.get("fsdp") is not None
is_fsdp2 = (
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if (
not is_multi_gpu
or (is_multi_gpu and not is_fsdp)
or (is_multi_gpu and is_fsdp2)
):
# Auto-enable kernels if not explicitly set by user
if data.get("lora_mlp_kernel") is None:
data["lora_mlp_kernel"] = True
if data.get("lora_qkv_kernel") is None:
data["lora_qkv_kernel"] = True
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
+ "See https://docs.axolotl.ai/docs/lora_optims.html for more info."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):

View File

@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name

View File

@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_epsilon2: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
adam_beta3: float | None = None
max_grad_norm: float | None = None
num_epochs: float = Field(default=1.0)

View File

@@ -67,6 +67,12 @@ class TRLConfig(BaseModel):
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
},
)
sync_ref_model: bool | None = Field(
default=False,
json_schema_extra={
@@ -133,3 +139,25 @@ class TRLConfig(BaseModel):
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)
epsilon_high: float | None = Field(
default=None,
json_schema_extra={
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
},
)
use_liger_loss: bool | None = Field(
default=None,
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
)
loss_type: str | None = Field(
default=None,
json_schema_extra={
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
},
)
mask_truncated_completions: bool = Field(
default=False,
json_schema_extra={
"description": "When enabled, truncated completions are excluded from the loss calculation."
},
)

View File

@@ -597,6 +597,8 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg):

View File

@@ -4,6 +4,7 @@ shared pytest fixtures
import functools
import importlib
import os
import shutil
import sys
import tempfile
@@ -529,31 +530,32 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass
@pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
reason="Not running in CI cache preload",
)
def test_load_fixtures(
download_smollm2_135m_model,
download_qwen_2_5_half_billion_model,
download_tatsu_lab_alpaca_dataset,
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,
download_mlabonne_finetome_100k_dataset,
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
download_argilla_dpo_pairs_dataset,
download_tiny_shakespeare_dataset,
download_deepseek_model_fixture,
download_huggyllama_model_fixture,
download_llama_1b_model_fixture,
download_llama3_8b_model_fixture,
download_llama3_8b_instruct_model_fixture,
download_phi_35_mini_model_fixture,
download_phi_3_medium_model_fixture,
download_mistral_7b_model_fixture,
download_gemma_2b_model_fixture,
download_gemma2_9b_model_fixture,
download_mlx_mistral_7b_model_fixture,
download_llama2_model_fixture,
):
pass

View File

@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
except FileNotFoundError:
pass
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_trainer_create\n")
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -72,7 +78,7 @@ class LogHooksPlugin(BasePlugin):
f.write("get_trainer_cls\n")
def create_lr_scheduler(
self, cfg, trainer, optimizer
self, cfg, trainer, optimizer, num_training_steps
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ class TestPluginHooks:
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "post_trainer_create" in file_contents
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents
@@ -172,7 +179,7 @@ class TestPluginHooks:
assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents
# assert "create_lr_scheduler" in file_contents # not implemented yet
assert "create_lr_scheduler" in file_contents
assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
)

View File

@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -227,7 +228,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
@@ -257,7 +258,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
@@ -265,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -320,7 +322,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
@@ -350,7 +352,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},

View File

@@ -479,7 +479,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -2,14 +2,19 @@
# pylint: disable=redefined-outer-name
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
from axolotl.cli.config import load_cfg
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
@@ -24,12 +29,12 @@ from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
@@ -39,7 +44,7 @@ MODEL_CONFIGS = [
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
@@ -66,29 +71,36 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
@pytest.mark.parametrize(
"model_name,attention_cls",
[
("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
],
)
def test_attention_patching_integration(model_name, attention_cls):
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
cfg = {"base_model": model_name}
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
original_forward = getattr(attention_cls, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
patched_forward = attention_cls.forward
# Check the forward method was replaced
assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored
assert hasattr(LlamaAttention, "_original_forward")
assert hasattr(attention_cls, "_original_forward")
# Clean up
setattr(LlamaAttention, "forward", original_forward)
delattr(LlamaAttention, "_original_forward")
setattr(attention_cls, "forward", original_forward)
delattr(attention_cls, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model):
@@ -144,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
"trl-internal-testing/tiny-Gemma2ForCausalLM",
torch_dtype=torch.float16,
device_map="cuda:0",
)
peft_config = get_peft_config(
{
@@ -413,3 +427,42 @@ def test_kernel_training_integration():
# Verify correct activation function
layer = model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
def test_kernel_training_integration_auto_enable(temp_dir):
"""Test model loading with auto-enabled kernel patches."""
# Create minimal config without explicitly setting kernel options
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Verify kernel options were auto-enabled in the config
assert cfg.lora_mlp_kernel is True
assert cfg.lora_qkv_kernel is True
assert cfg.lora_o_kernel is True

View File

@@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
}
)
@@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
}
)

View File

@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
E2E tests for activation checkpointing
"""
@pytest.mark.parametrize(
"gradient_checkpointing",
["offload", "offload_disk"],
)
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
gradient_checkpointing,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
"gradient_checkpointing": gradient_checkpointing,
}
)

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
Test case for Falcon models
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -99,9 +99,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)

View File

@@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)
@@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)

View File

@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -0,0 +1,63 @@
"""
Test case for handling embeddings when using peft
"""
import torch
from axolotl.train import setup_model_and_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
class TestLlamaPeftEmbeddings:
"""
test class for handling embeddings when using peft
"""
def test_peft_embeddings_upcast(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": False,
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
model, _, _, _ = setup_model_and_tokenizer(cfg)
# Check if the embeddings are upcast correctly
# only embed_tokens is a parameter that may be upcast
assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16

View File

@@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
}
)
@@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
}
)

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -26,6 +26,7 @@ class TestResumeLlama:
Test case for resuming training of llama models
"""
@require_torch_2_6_0
def test_resume_lora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -62,6 +63,7 @@ class TestResumeLlama:
"save_total_limit": 5,
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
}
)
if is_torch_bf16_gpu_available():

View File

@@ -0,0 +1,62 @@
"""E2E smoke test for evaluate CLI command"""
import os
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
os.environ["WANDB_DISABLED"] = "true"
class TestE2eEvaluate:
"""Test cases for evaluate CLI"""
def test_evaluate(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.evaluate",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
Test case for falcon
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
# pylint: disable=duplicate-code
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,

View File

@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_came_pytorch(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "came_pytorch",
"adam_beta3": 0.9999,
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"pad_to_sequence_len": False,
"flash_attention": True,
}
)
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
for record in self._caplog.records
)
def test_packing_autoset(self, minimal_cfg):
cfg = (
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
in record.message
for record in self._caplog.records
)
assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.

Some files were not shown because too many files have changed in this diff Show More