Compare commits

...

37 Commits

Author SHA1 Message Date
coderabbitai[bot]
0fccbadb79 📝 Add docstrings to 202512-raise_on_drop
Docstrings generation was requested by @kallewoof.

* https://github.com/axolotl-ai-cloud/axolotl/pull/3321#issuecomment-3668489902

The following files were modified:

* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
2025-12-18 05:49:01 +00:00
Seung Hyun Cho
3e51a680c2 fix: Fix evaluation loss in KD trainer (#3271)
* fix: Fix evaluation loss in KD trainer

* Fix v2 strategy super() call

* fix: Add safety check for total_tokens in log method

* fix: simplified num items and outputs return handling

* fix: add missing model forward pass in compute_loss

* refactor: Use Template Method pattern for chat template strategies

* refactor: use pop(None) and remove v2 override

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-17 13:40:36 -05:00
xzuyn
2cf254b4af Add peft_autocast_adapter_dtype config option (#3311) [skip ci]
* Add `peft_autocast_adapter_dtype` field to schema

* Add `autocast_adapter_dtype` to `model_kwargs`

* chore: docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-17 10:09:39 -05:00
salman
83d4d97dcc Add QAT NVFP4 configs for blogpost (#3280) [skip ci]
* add configs for blogpost

* fix configs

* fixing baseline configs
2025-12-17 09:35:22 -05:00
NanoCode012
a1d07f42e4 Fix(misc): address PYTORCH_CUDA_ALLOC_CONF deprecate (#3313)
* fix: leftover ministral docs changes

* fix: pytorch_cuda_alloc_conf deprecation

* fix: set old PYTORCH_CUDA_ALLOC_CONF env too

* handle 2.9 separately

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-17 09:12:18 -05:00
Wing Lian
2a664dc8ad support for xformers wheels for torch 2.9 (#3308)
* support for xformers wheels for torch 2.9

* fix hf cache?

* don't use hf cache from s3

* show disk free space in ci
2025-12-11 11:56:40 -05:00
NanoCode012
4ac78aa562 fix: update qwen3 jinja tokenization off a few tokens (#3295)
* fix: update qwen3 jinja tokenization off a few tokens

* fix: add note on tokenization issue

* fix: pop last index for mistral tokenizer
2025-12-09 14:31:03 +07:00
VED
b3f4aa149f fix bin size (#3307)
* fix bin size

* lint

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-08 09:16:18 -05:00
salman
75b20fb66f Save processor in quantizer CLI (#3290) 2025-12-06 16:27:18 +00:00
NanoCode012
5992e607a2 fix: improve ministral3 docs to be clearer (#3300)
* fix: improve ministral3 docs to be clearer

* fix: title

* chore: wording
2025-12-04 21:44:44 +07:00
NanoCode012
2b66ee189c Feat: add ministral3 (#3297)
* feat: add ministral and mistral3

* chore: lint

* feat: update cce for ministral

* fix: add vram usage

* feat: update for release

* fix: save_pretrained issue in v5

* fix: add instructions to use v5 branch

* fix: add to multipack

* fix: improve instructions

* fix: add model to readme
2025-12-04 08:32:08 -05:00
NanoCode012
86d8cca149 Feat: add trinity by ArceeAI (#3292) 2025-12-02 13:12:55 -05:00
NanoCode012
4a0f98e612 feat: upgrade liger to 0.6.4 (#3289) 2025-12-02 09:16:23 -05:00
Yohan Na
c6ddcdd06a feat: add exaone4 chat template and update enums (#3279)
* feat: add exaone4 chat template and update enums

* fix: handle first message as system or tools in exaone4 chat template

* Update src/axolotl/utils/chat_templates/templates/exaone4.jinja

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix: lint

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-01 15:52:45 +07:00
github-actions[bot]
7fb6a947d9 chore: update pre-commit hooks (#3287)
Co-authored-by: SalmanMohammadi <25081738+SalmanMohammadi@users.noreply.github.com>
2025-12-01 15:03:14 +07:00
NanoCode012
b234532d9f Feat: add peft_ensure_weight_tying (#3278)
* feat: upgrade peft to 0.18.0

* feat: add peft_ensure_weight_tying

* fix: default

* chore: adjust kwarg per feedback
2025-11-28 18:54:48 +07:00
VED
8990ca3205 fix: removed unused "scikit-learn==1.4.2" (#3277)
Co-authored-by: Ved <ved.work2024@gmail.com>
2025-11-24 13:48:53 +07:00
NanoCode012
006f226270 Feat: add Olmo3 (BC with Olmo and Olmo2) (#3275)
* feat: update cce to include olmo family

* chore: update docs following feedback

* feat: add olmo3 config

* fix: clarify 3 methods

* chore: add olmo to readme
2025-11-24 10:21:31 +07:00
Wing Lian
0b635e69c5 build docker images for 2.9.x (#3273) 2025-11-20 09:26:24 -05:00
Wing Lian
0d27e14e45 Torch 2.9.1 base images (#3268)
* update torch 2.9.1 base images

* update base dockerfile image check
2025-11-20 09:04:37 -05:00
NanoCode012
f5f21fb216 chore: update readme with latest updates (#3267)
Some checks failed
ci-cd / build-axolotl (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 126, 12.6.3, 3.11, 2.7.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, 3.11, 2.7.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 126, 12.6.3, <nil>, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (vllm, 126, 12.6.3, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 126, 12.6.3, <nil>, 3.11, 2.7.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, <nil>, 3.11, 2.8.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (vllm, 126, 12.6.3, true, 3.11, 2.7.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-11-18 14:45:21 +07:00
NanoCode012
4e55871112 feat: Add opt-out Telemetry (#3237)
* initial telemetry manager impl

* adding todo

* updates

* updates

* progress on telemetry: config load, process, model load, train start / end, error tracking

* update error file path sanitization function; adding more error tracking

* updated sanitization logic, tests

* adding runtime metrics (cpu + gpu memory, steps/s, etc.)

* tests for runtime metrics telemetry and assoc. callback

* small update / fix

* simplifying path redaction

* sleep on all ranks in distributed setting

* adding back in base_model redaction w/ whitelist

* fix

* doc update

* improved redaction, send system info during model config load telemetry, etc.

* adding runtime metrics / system info additional accelerator support, etc.

* adding runtime metrics / system info additional accelerator support, etc.

* remove duplicate info

* fixes

* fix issue with tests in ci

* distributed fix

* opt-in version of telemetry

* enable / disable logic update

* docs fix

* doc update

* minor fixes

* simplifying

* slight changes

* fix

* lint

* update posthog dep

* coderabbit comments

* fix: opt-in model

* fix: increase time since last

* fix: increase whitelist orgs

* fix: posthog init and shutdown

* fix: imports

* fix: also check grad norm

* fix: duplicate plugin_manager calls

* fix: bad merge

* chore: update docs

* fix: cache process per comment

* fix: error handling

* fix: tests

* Revert "fix: error handling"

This reverts commit 22d1ea5755.

* fix: test telemetry error_handled bool

* fix: revert test

* chore: final doc fixes

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-11-18 11:35:25 +07:00
Wing Lian
a6bafb55cb upgrade datasets to 4.4.1 (#3266)
* upgrade datasets

* cleanup pip cache earlier

* cleanup unused things from worker

* also cleanup sdist
2025-11-14 09:52:14 -08:00
Wing Lian
0fbde69e9c only push axolotl images, personal repo is deprecated (#3262)
* only push axolotl images, personal repo is deprecated

* cleanup
2025-11-14 07:50:03 -08:00
Wing Lian
301e22849f upgrade to latest deepspeed and make sure latest tagged axolotl images are using torch 2.8.0 (#3261) 2025-11-13 13:03:01 -05:00
VED
dcf24fd24e feat: save checkpoint after training started (#3233)
* add:config parameters for checkpoint

* callback main

* test file_type fix

* lint

* unit

* simplify dict/obj handeling

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Delete tests/e2e/integrations/__init__.py

* remove hard code path in test

* device check

* lint

* Update src/axolotl/utils/callbacks/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/utils/callbacks/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint-2

* remove: singal based checkpoints

* lint

* remove signal tests

* add:is_main_process

* lint

* addis_d:istributed() for tests

* remove nested is_main_process

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* add user_defined_filename

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-11-13 10:21:05 -05:00
NanoCode012
49b8107989 feat: add granite4 examples (#3256) [skip ci] 2025-11-13 10:19:16 -05:00
NanoCode012
9901ee5602 fix: voxtralprocessor broken (#3255) [skip ci]
* fix: voxtralprocessor broken

* chore: add todo

* chore: wording
2025-11-13 10:18:42 -05:00
xzuyn
dd78f2e0cc Fix: warmup_steps: 0 & warmup_ratio: 0 not disabling warmup (#3254)
* fix unintentional falsy checks

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-11-11 10:32:06 +07:00
Eduard Zl
b54f9c942b _get_tools in ChatTemplateStrategy : function "parameters" can be dict or string (#3238)
* When training of function calls, "tools" elements of a dataset can contain same parameter name but with different types. Datasets fails to load such training set. This fix allows "parameters" element of function call to be string( by running "json.dumps" in preparation of training data set). The _get_tools function will iterate over tool definitions, if "parameters" element is dict, it will keep that way, if it is a string, it will be converted to dict by invoking "json.loads" on string value.

* feat: add doc on tool parameters json loading

* feat: add tests for parameters json string

---------

Co-authored-by: ezlotnik <eduard_zlotnik@intuit.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-11-11 09:04:28 +07:00
NanoCode012
11eb36585a feat: add arg to enable dft in liger (#3125)
* feat: add arg to enable dft in liger

* feat: add tests use_token_scaling

* fix: test

* fix: move check to args
2025-11-10 21:37:47 +07:00
NanoCode012
d0c846fc5e feat: add granitemoeshared and granitemoehybrid (#3158) 2025-11-10 21:35:45 +07:00
Wing Lian
b5fcc2f14b log cumulative total trained tokens (#3252)
* log cumulative total trained tokens

* use is_distributed helper
2025-11-07 16:04:00 -05:00
Wing Lian
b62eed8809 add openenv-core to requirements (#3251) 2025-11-07 12:17:27 -05:00
VED
ed2e8cacd6 feat:openenv rollout_func (#3239) [skip ci]
* feat:openenv rollout_func

* chore lint

* docs

* add:docs processing_class

* tests

* lint
2025-11-07 08:51:40 -05:00
Lê Nam Khánh
80270a92fa Fix typos in some files (#3250) [skip ci] 2025-11-07 08:21:20 -05:00
Wing Lian
bfdc9a8249 upgrade trl and other hf deps (#3249)
* upgrade trl and other hf deps

* skip simpo for now
2025-11-06 16:06:03 -05:00
113 changed files with 5746 additions and 157 deletions

6
.github/FUNDING.yml vendored
View File

@@ -1,13 +1,13 @@
# These are supported funding model platforms
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: axolotl_ai # Replace with a single Ko-fi username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480&centerImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -57,14 +57,14 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
@@ -90,7 +90,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-base
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
@@ -147,14 +146,14 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:

View File

@@ -25,7 +25,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -36,6 +35,17 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -45,7 +55,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
@@ -99,7 +108,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -110,6 +118,17 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -119,7 +138,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
@@ -179,7 +197,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch

View File

@@ -31,7 +31,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
@@ -84,7 +83,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}

View File

@@ -59,15 +59,19 @@ jobs:
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -91,6 +95,10 @@ jobs:
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -105,9 +113,13 @@ jobs:
- name: Run tests
run: |
df -h
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
df -h
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
df -h
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
@@ -118,10 +130,6 @@ jobs:
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -134,15 +142,19 @@ jobs:
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -167,6 +179,10 @@ jobs:
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -176,7 +192,7 @@ jobs:
axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
run: hf cache scan
- name: Run tests
run: |
@@ -184,10 +200,6 @@ jobs:
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.3
rev: v0.14.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
rev: v1.19.0
hooks:
- id: mypy
additional_dependencies:
@@ -26,7 +26,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.8.6
rev: 1.9.2
hooks:
- id: bandit
args: [

View File

@@ -10,6 +10,7 @@ ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src

View File

@@ -29,6 +29,10 @@
## 🎉 Latest Updates
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
@@ -36,12 +40,12 @@
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
@@ -154,6 +158,13 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## 📈 Telemetry
Axolotl has opt-out telemetry that helps us understand how the project is being used
and prioritize improvements. We collect basic system information, model types, and
error rates—never personal data or file paths. Telemetry is enabled by default. To
disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).
## ❤️ Sponsors
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)

View File

@@ -241,6 +241,7 @@ website:
- docs/installation.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/telemetry.qmd
- docs/config-reference.qmd
- text: "API Reference"
href: docs/api

View File

@@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \

View File

@@ -218,6 +218,13 @@ If you have tool arguments with same name but different dtypes (like `"time": st
```
"arguments": "{\"...\": \"...\"}"
```
The same is applicable for tool parameters.
```
"parameters": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:

View File

@@ -4,7 +4,7 @@ format:
html:
toc: true
toc-depth: 3
number-sections: true
# number-sections: true
code-tools: true
execute:
enabled: false
@@ -14,12 +14,18 @@ This guide covers advanced training configurations for multi-GPU setups using Ax
## Overview {#sec-overview}
Axolotl supports several methods for multi-GPU training:
When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy.
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
You generally cannot combine these strategies; they are mutually exclusive.
1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3.
2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended).
3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected).
These features can often be combined with the strategies above:
* **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).
* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP).
## DeepSpeed {#sec-deepspeed}
@@ -65,12 +71,18 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.
::: {.callout-note}
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
:::
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
@@ -145,10 +157,6 @@ single sequence causes OOM errors during model training.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}
### Liger Kernel Integration {#sec-liger}

View File

@@ -124,6 +124,8 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-3 {#sec-gemma-3}

View File

@@ -597,6 +597,116 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
```python
# math_env.py
import re
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
"""
Custom rollout function that generates step-by-step math solutions.
Args:
model: The language model
processing_class: The tokenizer/processing_class
prompts: List of prompt dicts (with 'messages' key for chat format)
generation_config: Optional generation configuration
Returns:
List of completion strings
"""
completions = []
for prompt in prompts:
# Apply chat template to prompt
messages = prompt.get("messages", [])
formatted_prompt = processing_class.apply_chat_template(
messages, processing_class=False, add_generation_prompt=True
)
# Generate step-by-step solution
full_response = ""
for step in range(5): # Max 5 reasoning steps
current_input = formatted_prompt + full_response + "\nNext step:"
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=100,
generation_config=generation_config,
)
step_text = processing_class.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Check if solution is complete
if "FINAL ANSWER:" in step_text:
full_response += step_text
break
full_response += step_text + "\n"
completions.append(full_response)
return completions
def math_reward(prompts, completions, answers, **kwargs):
"""Reward function that checks mathematical correctness"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract predicted answer
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
predicted = match.group(1).strip() if match else ""
# Compare with correct answer
reward = 1.0 if predicted == str(correct_answer) else 0.0
rewards.append(reward)
return rewards
def math_transform(cfg, *args, **kwargs):
"""Transform dataset to GRPO format with answer field"""
def transform_fn(example, processing_class=None):
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": str(example["answer"]),
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 512
num_generations: 4
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
reward_funcs: ["math_env.math_reward"]
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
type: math_env.math_transform
```
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration
And should return a list of completion strings.
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.

61
docs/telemetry.qmd Normal file
View File

@@ -0,0 +1,61 @@
---
title: Telemetry
description: A description of the telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
Personally identifiable information (PII) is not collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 10 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`.
A warning message will be logged on start to clearly inform users about telemetry.
We will remove this after some period.
To hide the warning message about telemetry that is displayed on train, etc. startup,
explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
(explicitly disable telemetry).
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
]
},
{
@@ -253,7 +253,6 @@
"source": [
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
"\n",
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"set_pytorch_cuda_alloc_conf()"
]
},

View File

@@ -0,0 +1,65 @@
# Finetune IBM's Granite 4.0 with Axolotl
[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) are a family of open source models trained by IBM Research.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Granite4 is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/granite4/granite-4.0-tiny-fft.yaml
```
This config uses about 40.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### Limitation
Adapter finetuning does not work at the moment. It would error with
```bash
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648)
```
In addition, if adapter training works, `lora_target_linear: true` will not work due to:
```bash
ValueError: Target module GraniteMoeHybridParallelExperts() is not supported.
```
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Granite Docs](https://www.ibm.com/granite/docs/models/granite)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,45 @@
base_model: ibm-granite/granite-4.0-tiny-preview
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/model-out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -0,0 +1,50 @@
# Finetune Ministral with Axolotl
Ministral is a family of openweight models from MistralAI found on [HuggingFace](mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/ministral/ministral-small-qlora.yaml
```
This config uses about 8.76 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Tips
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,67 @@
base_model: mistralai/Ministral-8B-Instruct-2410
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,79 @@
# Finetune Ministral3 with Axolotl
Ministral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Please see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning.
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
Note: This is still experimental given it is based on transformers v5 RC.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Swap to the Axolotl transformers v5 branch
```bash
cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml
git fetch
git checkout transformers-v5
# Install packages for transformers v5
pip install -e .
```
4. Run the fine-tuning:
```bash
axolotl train ministral3-3b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### Tips
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### Thinking
Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
### Vision
Ministral3 2512 model also supports vision capabilities.
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,67 @@
base_model: mistralai/Ministral-3-3B-Reasoning-2512
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,73 @@
# Ministral3 2512 Thinking Fine-tuning
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md))
## Getting Started
Run the thinking model fine-tuning:
```bash
axolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml
```
This config uses about 4.76 GiB VRAM.
### Tips
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
## Dataset Format
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
Example format:
```json
{
"messages": [
{
"role": "system",
"content": [
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
]
},
{
"role": "user",
"content": [
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
]
},
{
"role": "assistant",
"content": [
{
"type": "thinking",
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
},
{
"type": "text",
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
}
]
}
]
}
```
### Advanced Options
The `thinking` section supports an optional `closed` parameter:
```json
{
"type": "thinking",
"thinking": "Internal reasoning here...",
"closed": true // Default: true, controls adding the closing [/THINK] tag
}
```

View File

@@ -0,0 +1,67 @@
base_model: mistralai/Ministral-3-3B-Reasoning-2512
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: Nanobit/text-think-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,57 @@
# Ministral3 2512 Vision Fine-tuning
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
## Getting started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.6'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml
```
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
### Tips
Key differences from text-only model:
- Multi-modal dataset format required
- Sample packing not supported
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
Example:
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [
{ "type": "text", "text": "What's in this image?"},
{"type": "image", "path": "path/to/image.jpg" }
]},
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
],
}
```
## Limitations
- Sample Packing is not supported for multi-modality training currently.

View File

@@ -0,0 +1,64 @@
base_model: mistralai/Ministral-3-3B-Reasoning-2512
processor_type: AutoProcessor
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

38
examples/olmo3/README.md Normal file
View File

@@ -0,0 +1,38 @@
# Finetune Allenai's Olmo 3 with Axolotl
[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) are a family of 7B and 32B models open source models trained by The Allen Institute for Artificial Intelligence.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- The example config can be re-used for Olmo and Olmo 2.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Olmo 3 Blog](https://allenai.org/blog/olmo3)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,64 @@
base_model: allenai/Olmo-3-7B-Instruct-SFT
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/out_gemma/
sequence_len: 8096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out_gemma/
sequence_len: 8096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/out_math_gemma/
sequence_len: 4096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/qat_out_math_gemma/
sequence_len: 4096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/out_math_gemma27/
sequence_len: 4096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,73 @@
base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/qat_out_math_gemma27/
sequence_len: 4096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-72B
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/out_math_72b/
sequence_len: 4096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-72B
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/qat_out_math_72b/
sequence_len: 4096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-72B
# Alpaca finetuning configuration for Qwen2.5-72B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/out_qwen72b/
sequence_len: 8096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-72B
# Alpaca finetuning configuration for Qwen2.5-72B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out_qwen72b/
sequence_len: 8096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

46
examples/qwen3/README.md Normal file
View File

@@ -0,0 +1,46 @@
# Finetune Qwen3 with Axolotl
[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/qwen3/32b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### Chat template masking a few tokens off
If you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.
```yaml
chat_template: qwen3
```
### TIPS
- For inference, please check the official model card as it depends on your reasoning mode.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -6,21 +6,17 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
@@ -41,9 +37,7 @@ Let us know how it goes. Happy finetuning! 🚀
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources

View File

@@ -37,9 +37,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources

View File

@@ -0,0 +1,38 @@
# Finetune ArceeAI's Trinity with Axolotl
[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,67 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# CCE - N/A as of now
# plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
# flash_attention: true # Not supported
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,5 +1,5 @@
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: AutoProcessor
processor_type: VoxtralProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,23 +1,23 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
bitsandbytes==0.48.2
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.3
liger-kernel==0.6.4
# END section
packaging==23.2
huggingface_hub>=0.36.0
peft>=0.17.1
tokenizers>=0.21.1
peft>=0.18.0
tokenizers>=0.22.1
transformers==4.57.1
accelerate==1.10.1
datasets==4.3.0
accelerate==1.11.0
datasets==4.4.1
deepspeed>=0.17.0
trl==0.24.0
trl==0.25.0
hf_xet==1.2.0
kernels>=0.9.0
trackio
@@ -42,7 +42,6 @@ numpy>=2.2.6
# qlora things
evaluate==0.4.1
scipy
scikit-learn==1.4.2
nvidia-ml-py==12.560.30
art
tensorboard
@@ -64,9 +63,13 @@ immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.5
mistral-common==1.8.5
# telemetry
posthog==6.7.11
mistral-common==1.8.6

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
)

View File

@@ -66,7 +66,6 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
@@ -130,7 +129,7 @@ extras_require = {
"ring-flash-attn>=0.1.7",
],
"deepspeed": [
"deepspeed==0.17.5",
"deepspeed==0.18.2",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -14,6 +14,8 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -31,6 +33,8 @@ LOG = get_logger(__name__)
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -164,6 +168,7 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg
@send_errors
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
@@ -197,6 +202,8 @@ def load_cfg(
temp_file.close()
cfg.axolotl_config_path = temp_file.name
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
@@ -240,6 +247,7 @@ def load_cfg(
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
cfg_to_log = {
k: "[REDACTED]" if k in API_KEY_FIELDS else v
for k, v in cfg.items()

View File

@@ -19,7 +19,10 @@ from axolotl.cli.utils.diffusion import (
launch_diffusion_gradio_ui,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -43,6 +46,7 @@ def get_multi_line_input() -> str:
return instruction
@send_errors
def do_inference(
*,
cfg: DictDefault,
@@ -160,6 +164,7 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio(
*,
cfg: DictDefault,

View File

@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -45,6 +45,7 @@ def cli():
print_axolotl_text_art()
load_dotenv()
set_pytorch_cuda_alloc_conf()
set_misc_env()
@cli.command()

View File

@@ -7,12 +7,14 @@ import fire
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -23,6 +23,7 @@ from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
@@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_
@send_errors
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,

View File

@@ -17,6 +17,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
@@ -24,6 +25,7 @@ from axolotl.utils.trainer import disable_datasets_caching
LOG = get_logger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.

View File

@@ -8,7 +8,7 @@ from typing import Union
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import (
TorchAOQuantDType,
@@ -66,6 +66,11 @@ def do_quantize(
LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg)
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
@@ -107,6 +112,10 @@ def do_quantize(
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
processor.save_pretrained(str(Path(output_dir) / "quantized"))
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
@@ -114,6 +123,8 @@ def do_quantize(
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -17,4 +17,5 @@ MOE_ARCH_BLOCK = {
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
}

View File

@@ -9,6 +9,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -34,6 +35,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
)
@send_errors
def load_datasets(
*,
cfg: DictDefault,
@@ -96,6 +98,7 @@ def load_datasets(
)
@send_errors
def load_preference_datasets(
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
) -> TrainDatasetMeta:

View File

@@ -29,6 +29,8 @@ from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import (
is_comet_available,
is_mlflow_available,
@@ -118,6 +120,13 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:
from axolotl.utils.callbacks.dynamic_checkpoint import (
DynamicCheckpointCallback,
)
callbacks.append(DynamicCheckpointCallback(self.cfg))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -155,6 +164,10 @@ class TrainerBuilderBase(abc.ABC):
)
)
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -196,9 +209,9 @@ class TrainerBuilderBase(abc.ABC):
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps:
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio:
elif self.cfg.warmup_ratio is not None:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:

View File

@@ -43,7 +43,7 @@ from axolotl.core.trainers.utils import (
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.distributed import is_distributed, is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -350,6 +350,11 @@ class AxolotlTrainer(
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
num_tokens = (inputs[inputs_key] != -100).sum()
if is_distributed():
torch.distributed.all_reduce(
num_tokens, op=torch.distributed.ReduceOp.SUM
)
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
@@ -357,6 +362,11 @@ class AxolotlTrainer(
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if hasattr(self.state, "total_tokens"):
self.state.total_tokens += num_tokens
else:
self.state.total_tokens = num_tokens
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -621,6 +631,11 @@ class AxolotlTrainer(
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
if (
hasattr(self.state, "total_tokens")
and self.state.total_tokens is not None
):
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval]

View File

@@ -126,6 +126,9 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs
@classmethod
@@ -201,3 +204,32 @@ class GRPOStrategy:
raise ValueError(
f"Reward function {reward_func_fqn} not found."
) from exc
@classmethod
def get_rollout_func(cls, rollout_func_fqn: str):
"""
Returns the rollout function from the given fully qualified name.
Args:
rollout_func_fqn (str): Fully qualified name of the rollout function
(e.g. my_module.my_rollout_func)
Returns:
Callable rollout function
"""
try:
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
rollout_func_module = importlib.import_module(
".".join(rollout_func_fqn.split(".")[:-1])
)
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
if not callable(rollout_func):
raise ValueError(
f"Rollout function {rollout_func_fqn} must be callable"
)
return rollout_func
except ModuleNotFoundError as exc:
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

View File

@@ -10,6 +10,7 @@ import torch
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.telemetry.errors import send_errors
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
@@ -63,6 +64,7 @@ def evaluate_dataset(
return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets.

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
```
## Usage
@@ -61,10 +61,15 @@ plugins:
- llama4
- llama4_text
- llava
- ministral
- ministral3
- mistral
- mistral3
- mixtral
- mllama
- olmo
- olmo2
- olmo3
- phi
- phi3
- phi4_multimodal

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
)

View File

@@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
# let subclasses add fields before transform
tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
"""
Hook for subclasses to prepare additional KD fields before transform
"""
return tokenized_prompt
@@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
return sample
def _tokenize_single_prompt(self, prompt):
target_token_ids = prompt.get("target_token_ids", None)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
"""
Add pre-tokenized target_token_ids for v2 format
"""
target_token_ids = original_prompt.pop("target_token_ids", None)
if target_token_ids is not None:
tokenized_prompt["target_token_ids"] = target_token_ids
return tokenized_prompt

View File

@@ -16,6 +16,8 @@
KD trainer
"""
from typing_extensions import override
from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
@@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
if columns_to_add:
self._signature_columns += columns_to_add
@override
def compute_loss(
self,
model,
@@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
):
del inputs["attention_mask"]
if num_items_in_batch is None and "labels" in inputs:
num_items_in_batch = (inputs["labels"] != -100).sum().item()
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
return outputs[0]
if isinstance(outputs, dict):
loss = outputs["loss"]
elif isinstance(outputs, tuple):
loss = outputs[0]
else:
loss = outputs.loss if hasattr(outputs, "loss") else outputs
return (loss, outputs) if return_outputs else loss

View File

@@ -18,6 +18,9 @@ liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
# FLCE-specific
liger_use_token_scaling: true
```
## Supported Models

View File

@@ -16,7 +16,7 @@
Module for handling LIGER input arguments.
"""
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator
from axolotl.utils.logging import get_logger
@@ -35,6 +35,15 @@ class LigerArgs(BaseModel):
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
liger_use_token_scaling: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables use_token_scaling in fused_linear_cross_entropy. "
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
)
},
)
@model_validator(mode="before")
@classmethod
@@ -75,6 +84,18 @@ class LigerArgs(BaseModel):
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_use_token_scaling_flce(cls, data):
if data.get("liger_use_token_scaling") and not data.get(
"liger_fused_linear_cross_entropy"
):
raise ValueError(
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate

View File

@@ -48,6 +48,33 @@ class LigerPlugin(BasePlugin):
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.liger_use_token_scaling:
# Patch FLCE to set token_scaling=True for function and class API
from liger_kernel.transformers import functional
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
old_liger_fused_linear_cross_entropy = (
functional.liger_fused_linear_cross_entropy
)
def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
kwargs["use_token_scaling"] = True
return old_liger_fused_linear_cross_entropy(*args, **kwargs)
functional.liger_fused_linear_cross_entropy = (
patched_liger_fused_linear_cross_entropy
)
old_init = LigerFusedLinearCrossEntropyLoss.__init__
def patched_init(self, *args, **kwargs):
kwargs["use_token_scaling"] = True
return old_init(self, *args, **kwargs)
LigerFusedLinearCrossEntropyLoss.__init__ = patched_init
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -20,6 +20,7 @@ from peft import (
from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -101,6 +102,8 @@ def load_lora(
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
if cfg.peft_ensure_weight_tying is not None:
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
# Determine the correct PEFT task type
model_cls = type(model).__name__
@@ -139,9 +142,12 @@ def load_lora(
):
setup_quantized_meta_for_peft(model)
model_kwargs: Any = {}
if cfg.peft_autocast_adapter_dtype is not None:
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
model_kwargs: Any = {}
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
@@ -152,7 +158,7 @@ def load_lora(
**model_kwargs,
)
else:
model = get_peft_model(model, lora_config)
model = get_peft_model(model, lora_config, **model_kwargs)
if rank == 0:
try:
@@ -172,6 +178,7 @@ def load_lora(
return model, lora_config
@send_errors
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,

View File

@@ -49,6 +49,7 @@ from axolotl.loaders.utils import (
load_model_config,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.telemetry.errors import send_errors
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
@@ -158,6 +159,7 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.

View File

@@ -457,7 +457,7 @@ class PatchManager:
and self.cfg.flash_attention
and not self.inference
):
# TODO(MengqingCao): split these patches seperately
# TODO(MengqingCao): split these patches separately
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,

View File

@@ -1,27 +1,47 @@
"""Processor loading functionality for multi-modal models"""
from typing import Any
import transformers
from transformers import (
AutoProcessor,
PreTrainedTokenizerBase,
)
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
processor_cls = AutoProcessor
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer():
"""
Transformers v5 stops reading the sub-processor.
We need to patch this, so both processors use this.
"""
import transformers.tokenization_mistral_common as tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
_patch_mistralcommontokenizer()
from transformers import VoxtralProcessor
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
)
from axolotl.utils.mistral import Mistral3Processor
return Mistral3Processor(
@@ -32,7 +52,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
# Attempt to load image size from processor if available

View File

@@ -13,6 +13,7 @@ from transformers import (
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
@@ -119,6 +120,7 @@ def modify_tokenizer_files(
return tokenizer_dir
@send_errors
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config."""

View File

@@ -40,6 +40,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"smollm3",
"granite",
"granitemoe",
"granitemoeshared",
"granitemoehybrid",
"hunyuan_v1_dense",
"hunyuan_v1_moe",
"gpt_oss",
@@ -47,6 +49,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"seed_oss",
"lfm2",
"lfm2_moe",
"olmo",
"olmo2",
"olmo3",
"ministral",
"ministral3",
"afmoe",
]

View File

@@ -95,6 +95,7 @@ class ChatTemplatePrompter(Prompter):
add_generation_prompt=False,
images=None,
tools=None,
real_last_index=None,
):
"""
Build a prompt from a conversation.
@@ -114,6 +115,9 @@ class ChatTemplatePrompter(Prompter):
if tools:
chat_template_kwargs["tools"] = tools
if real_last_index:
chat_template_kwargs["real_last_index"] = real_last_index
if self.processor:
if not callable(self.processor):
raise TypeError("Processor must be callable")
@@ -631,11 +635,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
turns_with_empty = turns[:turn_idx] + [empty_turn]
turns_with_content = turns[: turn_idx + 1]
real_last_index = len(turns) - 1
# Generate the conversation up to the turn, with final turn replaced with dummy content
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
dummy_ids = self.prompter.build_prompt(
turns_with_empty, tools=tools, real_last_index=real_last_index
) # type: ignore
# Generate the conversation up to the turn, with final turn included
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
full_ids = self.prompter.build_prompt(
turns_with_content, tools=tools, real_last_index=real_last_index
) # type: ignore
if not full_ids or not dummy_ids:
LOG.warning(f"Empty template generated for turn {turn_idx}")
@@ -823,6 +833,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return None
if isinstance(tools, list):
# Process each tool to handle JSON string parameters
for tool in tools:
if isinstance(tool, dict) and "function" in tool:
function = tool["function"]
if "parameters" in function:
params = function["parameters"]
if isinstance(params, str):
try:
function["parameters"] = json.loads(params)
except json.JSONDecodeError as e:
LOG.error(
f"Error parsing tool parameters as JSON. "
f"Function: {function.get('name', 'unknown')}, "
f"Parameters string: {params!r}, "
f"Error: {e}"
)
raise
return tools
raise ValueError(

View File

@@ -0,0 +1,165 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
import logging
import time
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
LOG = logging.getLogger(__name__)
TIME_SINCE_LAST = 60
class TelemetryCallback(TrainerCallback):
"""
Trainer callback for tracking and reporting runtime metrics.
This callback tracks training progress, runtime, and memory usage,
sending telemetry at configurable intervals.
"""
report_interval_steps: int = 100
def __init__(self):
"""Initialize the metrics callback."""
self.tracker = RuntimeMetricsTracker()
self.telemetry_manager = TelemetryManager.get_instance()
self.current_epoch = -1
self.start_time = time.time()
self.last_report_time = None
self.last_report_step = 0
# pylint: disable=unused-argument
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training start."""
self.telemetry_manager.send_event(event_type="train-start")
# pylint: disable=unused-argument
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training end."""
# Send training completion event
self.telemetry_manager.send_event(
event_type="train-end",
properties=self._extract_last_metrics(state)
| self.tracker.metrics.to_dict(),
)
# pylint: disable=unused-argument
def on_epoch_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch start."""
self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch end."""
self.tracker.end_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle step end."""
step = state.global_step
self.tracker.update_step(step)
# Check if we should report metrics
should_report = (
step % self.report_interval_steps == 0
or step == 1 # Always report first step
or step - self.last_report_step >= self.report_interval_steps
)
if should_report:
current_time = time.time()
if self.last_report_time is not None:
time_since_last_report = current_time - self.last_report_time
else:
time_since_last_report = current_time - self.start_time
steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed
if (
step == 1
or time_since_last_report >= TIME_SINCE_LAST
or steps_since_last_report >= self.report_interval_steps
):
# Calculate steps per second for this interval
if time_since_last_report > 0 and steps_since_last_report > 0:
steps_per_second = steps_since_last_report / time_since_last_report
else:
steps_per_second = 0
# Update memory metrics
self.tracker.update_memory_metrics()
# Prepare metrics to report
metrics = self._extract_last_metrics(state) | {
"step": step,
"epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress
"steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report,
}
# Add memory metrics
memory_metrics = self.tracker.get_memory_metrics()
metrics.update({"memory": memory_metrics})
# Send telemetry
self.telemetry_manager.send_event(
event_type="train-progress", properties=metrics
)
# Update last report time and step
self.last_report_time = current_time
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, and grad_norm from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0),
}

View File

@@ -0,0 +1,160 @@
"""Telemetry utilities for exception and traceback information."""
import logging
import os
import re
import traceback
from functools import wraps
from inspect import getmodule
from typing import Any, Callable
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
ERROR_HANDLED = False
def sanitize_stack_trace(stack_trace: str) -> str:
"""
Remove personal information from stack trace messages while keeping Python package codepaths.
This function identifies Python packages by looking for common patterns in virtual environment
and site-packages directories, preserving the package path while removing user-specific paths.
Args:
stack_trace: The original stack trace string.
Returns:
A sanitized version of the stack trace with Python package paths preserved.
"""
# Split the stack trace into lines to process each file path separately
lines = stack_trace.split("\n")
sanitized_lines = []
# Regular expression to find file paths in the stack trace
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
# Regular expression to identify paths in site-packages or dist-packages
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
site_packages_pattern = re.compile(
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
# Additional common virtual environment patterns
venv_lib_pattern = re.compile(
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
for line in lines:
# Check if this line contains a file path
path_match = path_pattern.search(line)
if path_match:
full_path = path_match.group(1)
sanitized_path = ""
# Try to match site-packages pattern
site_packages_match = site_packages_pattern.search(full_path)
venv_lib_match = venv_lib_pattern.search(full_path)
if site_packages_match:
# Find the index where the matched pattern starts
idx = full_path.find("site-packages")
if idx == -1:
idx = full_path.find("dist-packages")
# Keep from 'site-packages' onward
if idx >= 0:
sanitized_path = full_path[idx:]
elif venv_lib_match:
# For other virtual environment patterns, find the package directory
match_idx = venv_lib_match.start(1)
if match_idx > 0:
# Keep from the package name onward
package_name = venv_lib_match.group(1)
idx = full_path.rfind(
package_name, 0, match_idx + len(package_name)
)
if idx >= 0:
sanitized_path = full_path[idx:]
# If we couldn't identify a package pattern but path contains 'axolotl'
elif "axolotl" in full_path:
idx = full_path.rfind("axolotl")
if idx >= 0:
sanitized_path = full_path[idx:]
# Apply the sanitization to the line
if sanitized_path:
line = line.replace(full_path, sanitized_path)
else:
# If we couldn't identify a package pattern, just keep the filename
filename = os.path.basename(full_path)
if filename:
line = line.replace(full_path, filename)
else:
line = line.replace(full_path, "")
sanitized_lines.append(line)
return "\n".join(sanitized_lines)
def send_errors(func: Callable) -> Callable:
"""
Decorator to send exception info in a function. If an exception is raised, we send
telemetry containing the stack trace and error message.
If an error occurs in a decorated function that is called by another decorated
function, we'll only send telemetry corresponding to the lower-level function.
Args:
func: Function to decorate.
Returns:
Decorated function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True
# Get function module path
module = getmodule(func)
module_path = (
f"{module.__name__}.{func.__name__}" if module else func.__name__
)
# Get stack trace
stack_trace = "".join(
traceback.format_exception(
type(exception), exception, exception.__traceback__
)
)
stack_trace = sanitize_stack_trace(stack_trace)
# Send error telemetry
telemetry_manager.send_event(
event_type=f"{module_path}-error",
properties={
"exception": str(exception),
"stack_trace": stack_trace,
},
)
raise
return wrapper

View File

@@ -0,0 +1,416 @@
"""Telemetry manager and associated utilities."""
import atexit
import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any
import posthog
import psutil
import torch
import yaml
LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes
FIELDS_TO_REDACT = {
"base_model",
"tokenizer_config",
"base_model_config",
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
"resume_from_checkpoint",
"hub_model_id",
}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
"torch",
"transformers",
"trl",
"datasets",
"peft",
"bitsandbytes",
"accelerate",
"optimum",
"deepspeed",
"ray",
"axolotl",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"tokenizers",
"sentencepiece",
"torchao",
"lm_eval",
}
def is_main_process() -> bool:
"""
Check whether we're running in the main process.
Note:
We're using this function instead of `torch.utils.distributed.is_main_process`
causes issues with DeepSpeed world_size since. This function avoids that issue
by checking env vars that are set by various launchers.
Returns:
Whether we're running in the main process.
"""
# If PyTorch distributed is already initialized, use it
if torch.distributed.is_initialized():
return torch.distributed.get_rank() == 0
# Otherwise check environment variables for global rank
# NOTE: need to verify this in SLURM / OpenMPI environments
global_rank = int(
os.environ.get(
"RANK",
os.environ.get(
"GLOBAL_RANK",
os.environ.get(
"SLURM_PROCID",
os.environ.get(
"OMPI_COMM_WORLD_RANK",
"0",
),
),
),
)
)
return global_rank == 0
class TelemetryManager:
"""Manages telemetry collection and transmission"""
_instance = None
_initialized = False
def __new__(cls):
"""
Telemetry manager constructor. Creates the singleton instance of this class if
it doesn't already exist.
"""
if cls._instance is None:
cls._instance = super(TelemetryManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Telemetry manager initializer"""
if self._initialized:
return
self.enabled = self._check_telemetry_enabled()
if self.enabled:
self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist()
try:
self.system_info = self._get_system_info()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Error during system info collection: {e}")
self.system_info = None
self._init_posthog()
# Register shutdown method to flush posthog telemetry
atexit.register(self.shutdown)
self._initialized = True
@classmethod
def get_instance(cls) -> "TelemetryManager":
if cls._instance is None:
cls._instance = TelemetryManager()
return cls._instance
def _check_telemetry_enabled(self) -> bool:
"""
Check if telemetry is enabled based on environment variables. We also check
whether this is the main process (for the distributed setting and to avoid
sending duplicate PostHog events per GPU).
Note: This is enabled by default on an opt-out basis. Set
`AXOLOTL_DO_NOT_TRACK=1` to disable telemetry. For more details, see
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns:
Boolean denoting whether telemetry is enabled or not.
"""
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
# Default to enabled (opt-out model)
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
"0",
"1",
"false",
"true",
):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True
# Only rank 0 will send telemetry
if not is_main_process():
return False
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
return enabled
def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist"""
with open(WHITELIST_PATH, encoding="utf-8") as f:
whitelist = yaml.safe_load(f)
# Send org strings to lowercase since model names are case insensitive
whitelist["organizations"] = {
org.lower() for org in whitelist["organizations"]
}
return whitelist
def _is_whitelisted(self, value: str) -> bool:
"""
Check if model / dataset / etc. org is in whitelist.
Args:
value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS`
("base_model", etc.).
Returns:
Boolean indicating whitelist membership.
"""
# NOTE: This membership-checking logic can be improved.
# What happens when a local model path matches a whitelisted org?
parts = value.split("/")
if len(parts) < 2:
return False
org = parts[0]
whitelisted = org.lower() in self.whitelist["organizations"]
return whitelisted
def _init_posthog(self):
"""Initialize PostHog client"""
posthog.api_key = POSTHOG_WRITE_KEY
posthog.project_api_key = POSTHOG_WRITE_KEY
posthog.host = POSTHOG_HOST
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
"""
Redact properties to remove any paths, so as to avoid inadvertently collecting
private or personally identifiable information (PII). We also remove
information related to Wandb, MLflow, etc. configuration.
Args:
properties: Dictionary of properties to redact.
Returns:
Properties dictionary with redaction applied.
"""
if not properties:
return {}
def redact_value(value: Any, key: str = "") -> Any:
"""Recursively sanitize values, redacting those with path-like keys"""
if isinstance(key, str) and isinstance(value, str):
# Other redaction special cases
if (
key in FIELDS_TO_REDACT
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
):
# Fields with whitelisted orgs don't need to be redacted
if not self._is_whitelisted(value):
return "[REDACTED]"
# Handle nested values
if isinstance(value, dict):
return {k: redact_value(v, k) for k, v in value.items()}
if isinstance(value, list):
return [redact_value(item) for item in value]
return value
# Create new dict with redacted values
redacted = {k: redact_value(v, k) for k, v in properties.items()}
return redacted
def _get_system_info(self) -> dict[str, Any]:
"""Collect system information for various hardware accelerators"""
gpu_info = []
accelerator_type = "none"
# NVIDIA GPUs
if torch.cuda.is_available():
accelerator_type = "cuda"
for i in range(torch.cuda.device_count()):
gpu_info.append(
{
"name": torch.cuda.get_device_name(i),
"memory": torch.cuda.get_device_properties(i).total_memory,
}
)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
accelerator_type = "hip"
for i in range(torch.hip.device_count()):
gpu_info.append(
{
"name": torch.hip.get_device_name(i),
"memory": (
torch.hip.get_device_properties(i).total_memory
if hasattr(torch.hip, "get_device_properties")
else None
),
}
)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
accelerator_type = "mps"
gpu_info.append(
{
"name": "Apple Silicon",
# NOTE: this is memory allocated to this process, not total memory
"memory": torch.mps.driver_allocated_memory(),
}
)
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
accelerator_type = "xpu"
for i in range(torch.xpu.device_count()):
memory = None
if hasattr(torch.xpu, "get_device_properties"):
memory = torch.xpu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.xpu.get_device_name(i),
"memory": memory,
}
)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
accelerator_type = "npu"
for i in range(torch.npu.device_count()):
memory = None
if hasattr(torch.npu, "get_device_properties"):
memory = torch.npu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.npu.get_device_name(i),
"memory": memory,
}
)
# Get relevant package versions
installed_packages = {}
for package in RELEVANT_PACKAGES:
try:
version = importlib.metadata.version(package)
installed_packages[f"{package}_version"] = version
except importlib.metadata.PackageNotFoundError:
pass
return {
"os": platform.system(),
"python_version": platform.python_version(),
"cpu_count": psutil.cpu_count(),
"memory_total": psutil.virtual_memory().total,
"accelerator_type": accelerator_type,
"accelerator_count": len(gpu_info),
"accelerator_info": gpu_info,
**installed_packages,
}
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
"""Send a telemetry event"""
if not self.enabled:
return
if properties is None:
properties = {}
# Sanitize properties to remove PII
properties = self._redact_paths(properties)
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try:
# Send event via PostHog
posthog.capture(
distinct_id=self.run_id,
event=event_type,
properties=properties,
disable_geoip=True,
)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Failed to send telemetry event: {e}")
# Additionally, send system info telemetry when loading config.
# NOTE: Is this the best place for this?
if event_type == "config-loaded":
self.send_system_info()
def send_system_info(self):
"""Helper method for sending system info"""
if self.system_info is not None:
self.send_event(event_type="system-info", properties=self.system_info)
def shutdown(self):
"""Ensure all queued events are processed before shutdown"""
if self.enabled:
posthog.shutdown()

View File

@@ -0,0 +1,210 @@
"""Telemetry utilities for runtime and memory metrics."""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import psutil
import torch
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
@dataclass
class RuntimeMetrics:
"""Container for runtime metrics to be tracked throughout training."""
# Timing metrics
start_time: float
epoch_start_times: dict[int, float] = field(init=False)
epoch_end_times: dict[int, float] = field(init=False)
# Memory metrics
peak_cpu_memory: int = 0
peak_gpu_memory: dict[int, int] = field(init=False)
# Progress metrics
total_steps: int = 0
current_epoch: int = 0
current_step: int = 0
def __post_init__(self):
"""Initialize empty metric mappings."""
self.epoch_start_times = {}
self.epoch_end_times = {}
self.peak_gpu_memory = {}
@property
def elapsed_time(self) -> float:
"""Calculate total elapsed time in seconds."""
return time.time() - self.start_time
def epoch_time(self, epoch: int) -> float | None:
"""Calculate time taken for a specific epoch in seconds."""
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
return None
def average_epoch_time(self) -> float | None:
"""Calculate average time per epoch in seconds."""
completed_epochs = [
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
]
if not completed_epochs:
return None
total_time = 0.0
for epoch in completed_epochs:
epoch_time = self.epoch_time(epoch)
if epoch_time is not None: # Check to avoid mypy warning
total_time += epoch_time
return total_time / len(completed_epochs)
def steps_per_second(self) -> float | None:
"""Calculate average steps per second across all training."""
if self.total_steps == 0 or self.elapsed_time == 0:
return None
return self.total_steps / self.elapsed_time
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to a dictionary for telemetry reporting."""
metrics = {
"total_time_seconds": self.elapsed_time,
"total_steps": self.total_steps,
"steps_per_second": self.steps_per_second(),
"epochs_completed": len(
[
epoch
for epoch in self.epoch_start_times
if epoch in self.epoch_end_times
]
),
"peak_cpu_memory_bytes": self.peak_cpu_memory,
}
# Add per-epoch timing if available
epoch_times: dict[str, float] = {}
for epoch in sorted(self.epoch_end_times.keys()):
time_taken = self.epoch_time(epoch)
if time_taken is not None:
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
if epoch_times:
metrics["epoch_times"] = epoch_times # type: ignore
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
# Add GPU memory metrics if available
if self.peak_gpu_memory:
gpu_metrics: dict[str, int] = {}
for gpu_id, memory in self.peak_gpu_memory.items():
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
metrics["gpu_memory"] = gpu_metrics # type: ignore
return metrics
class RuntimeMetricsTracker:
"""Tracker for runtime metrics during training."""
update_interval = 100
def __init__(self):
"""Initialize the runtime metrics tracker."""
self.metrics = RuntimeMetrics(start_time=time.time())
self.telemetry_manager = TelemetryManager.get_instance()
self._process = psutil.Process()
def start_epoch(self, epoch: int):
"""Record the start of a new epoch."""
self.metrics.current_epoch = epoch
self.metrics.epoch_start_times[epoch] = time.time()
self.update_memory_metrics()
def end_epoch(self, epoch: int):
"""Record the end of an epoch."""
self.metrics.epoch_end_times[epoch] = time.time()
def update_step(self, step: int):
"""Update the current step count."""
self.metrics.current_step = step
self.metrics.total_steps += 1
# Periodically update memory metrics
if step % self.update_interval == 0:
self.update_memory_metrics()
def _get_allocated_memory(self) -> dict[int, int]:
"""
Helper function for getting accelerator-agnostic allocated memory.
Returns:
A dictionary mapping device IDs to allocated memory in bytes
"""
memory_used: dict[int, int] = {}
# NVIDIA GPUs
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_used[i] = torch.cuda.memory_allocated(i)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
for i in range(torch.hip.device_count()):
if hasattr(torch.hip, "memory_allocated"):
memory_used[i] = torch.hip.memory_allocated(i)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# MPS doesn't have per-device memory stats since there's only one device
if hasattr(torch.mps, "current_allocated_memory"):
memory_used[0] = torch.mps.current_allocated_memory()
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
for i in range(torch.xpu.device_count()):
if hasattr(torch.xpu, "memory_allocated"):
memory_used[i] = torch.xpu.memory_allocated(i)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
for i in range(torch.npu.device_count()):
if hasattr(torch.npu, "memory_allocated"):
memory_used[i] = torch.npu.memory_allocated(i)
return memory_used
def update_memory_metrics(self):
"""Update peak memory usage metrics."""
# CPU memory
cpu_memory = self._process.memory_info().rss
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
self.metrics.peak_gpu_memory[i] = max(
self.metrics.peak_gpu_memory.get(i, 0), memory
)
def get_memory_metrics(self) -> dict[str, Any]:
"""Get the current memory metrics as a dictionary."""
memory_metrics = {
"cpu_memory_bytes": self._process.memory_info().rss,
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
}
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
self.metrics.peak_gpu_memory.get(i, 0)
)
return memory_metrics

View File

@@ -0,0 +1,33 @@
organizations:
- "axolotl-ai-co"
- "meta-llama"
- "huggingface"
- "nvidia"
- "facebook"
- "google"
- "microsoft"
- "deepseek-ai"
- "HuggingFaceTB"
- "mistralai"
- "Qwen"
- "unsloth"
- "NousResearch"
- "allenai"
- "amd"
- "tiiuae"
- "tencent"
- "zai-org"
- "openai"
- "ibm-granite"
- "arcee-ai"
- "swiss-ai"
- "CohereForAI"
- "deepcogito"
- "THUDM"
- "ai21labs"
- "LiquidAI"
- "canopylabs"
- "state-spaces"
- "mistral-community"
- "llava-hf"
- "ByteDance-Seed"

View File

@@ -31,6 +31,8 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -45,6 +47,9 @@ if typing.TYPE_CHECKING:
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_model_and_tokenizer(
cfg: DictDefault,
@@ -62,7 +67,10 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
@@ -78,6 +86,14 @@ def setup_model_and_tokenizer(
if model.generation_config is not None:
model.generation_config.do_sample = True
TELEMETRY_MANAGER.send_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -196,8 +212,7 @@ def execute_training(
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, trainer.model)
PLUGIN_MANAGER.post_train(cfg, trainer.model)
def save_trained_model(
@@ -521,9 +536,7 @@ def setup_model_and_trainer(
model_ref=model_ref,
peft_config=peft_config,
)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
if cfg.use_ray:
try:
@@ -545,6 +558,7 @@ def setup_model_and_trainer(
)
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
@@ -595,5 +609,6 @@ def train(
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
PLUGIN_MANAGER.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -41,14 +41,27 @@ def get_pytorch_version() -> tuple[int, int, int]:
def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
"""Set up CUDA allocation config"""
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)
config_value = "expandable_segments:True,roundup_power2_divisions:16"
if (
torch_major == 2
and torch_minor >= 9
and os.getenv("PYTORCH_ALLOC_CONF") is None
):
os.environ["PYTORCH_ALLOC_CONF"] = config_value
elif (
torch_major == 2
and torch_minor >= 2
and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value
def set_misc_env():
if os.getenv("XFORMERS_IGNORE_FLASH_VERSION_CHECK") is None:
os.environ["XFORMERS_IGNORE_FLASH_VERSION_CHECK"] = "1"
def get_not_null(value, default=None):

View File

@@ -0,0 +1,132 @@
from pathlib import Path
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.distributed import (
barrier,
is_distributed,
is_main_process,
)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save"
class DynamicCheckpointCallback(TrainerCallback):
"""
Callback to save checkpoints on-demand during training via:
1. File-based trigger (works everywhere, rank 0 checks file)
Thread-safe for multi-GPU distributed training.
Usage:
# File-based:
touch /path/to/output_dir/axolotl_checkpoint.save
"""
def _get_config_value(self, config, key, default=None):
"""Helper to get config value from dict or object."""
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
def __init__(self, cfg):
self.cfg = cfg
if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled:
self.enabled = False
return
self.enabled = True
dc_config = cfg.dynamic_checkpoint
trigger_file_path = self._get_config_value(dc_config, "trigger_file_path")
self.trigger_filename = (
trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME
)
check_interval = self._get_config_value(dc_config, "check_interval")
self.check_interval = check_interval if check_interval is not None else 100
self.should_save_checkpoint = False
LOG.info(
f"Dynamic checkpoint enabled. To trigger checkpoint save:\n"
f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
f" • Check interval: every {self.check_interval} steps",
main_process_only=True,
)
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
"""
Check for checkpoint triggers at the end of each step.
ONLY rank 0 checks the file, then all ranks synchronize.
"""
if not self.enabled:
return control
trigger_detected = False
if state.global_step % self.check_interval == 0:
if is_main_process():
trigger_path = Path(args.output_dir) / self.trigger_filename
if trigger_path.exists():
trigger_detected = True
try:
trigger_path.unlink() # Delete the trigger file
LOG.info(
f"Dynamic checkpoint triggered via file '{self.trigger_filename}' "
f"at step {state.global_step}",
main_process_only=True,
)
except OSError as exc:
LOG.warning(
f"Failed to delete trigger file: {exc}",
main_process_only=True,
)
if self.should_save_checkpoint:
trigger_detected = True
self.should_save_checkpoint = False # Reset flag
if is_distributed():
import torch
import torch.distributed as dist
device = getattr(
args,
"device",
torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
trigger_tensor = torch.tensor(
1 if trigger_detected else 0,
dtype=torch.long,
device=device,
)
dist.broadcast(trigger_tensor, src=0)
trigger_detected = bool(trigger_tensor.item())
barrier()
if trigger_detected:
control.should_save = True
LOG.info(
f"Saving dynamic checkpoint at step {state.global_step}",
main_process_only=True,
)
return control

View File

@@ -0,0 +1,126 @@
{%- if not skip_think is defined %}
{%- set skip_think = true %}
{%- endif %}
{%- set role_indicators = {
'user': '[|user|]\n',
'assistant': '[|assistant|]\n',
'system': '[|system|]\n',
'tool': '[|tool|]\n'
} %}
{%- set end_of_turn = '[|endofturn|]\n' %}
{%- macro available_tools(tools) %}
{{- "# Available Tools" }}
{{- "\nYou can use none, one, or multiple of the following tools by calling them as functions to help with the users query." }}
{{- "\nHere are the tools available to you in JSON format within <tool> and </tool> tags:\n" }}
{%- for tool in tools %}
{{- "<tool>" }}
{{- tool | tojson(ensure_ascii=False) | safe }}
{{- "</tool>\n" }}
{%- endfor %}
{{- "\nFor each function call you want to make, return a JSON object with function name and arguments within <tool_call> and </tool_call> tags, like:" }}
{{- "\n<tool_call>{\"name\": function_1_name, \"arguments\": {argument_1_name: argument_1_value, argument_2_name: argument_2_value}}</tool_call>" }}
{{- "\n<tool_call>{\"name\": function_2_name, \"arguments\": {...}}</tool_call>\n..." }}
{{- "\nNote that if no argument name is specified for a tool, you can just print the argument value directly, without the argument name or JSON formatting." }}
{%- endmacro %}
{%- set ns = namespace(last_query_index = messages|length - 1) %}
{%- for message in messages %}
{%- if message.role == "user" and message.content is string %}
{%- set ns.last_query_index = loop.index0 -%}
{%- endif %}
{%- endfor %}
{%- for i in range(messages | length) %}
{%- set msg = messages[i] %}
{%- set role = msg.role %}
{%- if role not in role_indicators %}
{{- raise_exception('Unknown role: ' ~ role) }}
{%- endif %}
{# ---- Case A: If the first message is "system", handle it here alone (without continue) ---- #}
{%- if i == 0 and role == 'system' %}
{{- role_indicators['system'] }}
{{- msg.content }}
{%- if tools is defined and tools %}
{{- "\n\n" }}{{- available_tools(tools) }}
{%- endif %}
{{- end_of_turn -}}
{%- else %}
{# ---- Case B: If the first message is tools instead of system, inject the system tools preamble ---- #}
{%- if i == 0 and tools is defined and tools %}
{{- role_indicators['system'] }}
{{- available_tools(tools) }}
{{- end_of_turn -}}
{%- endif %}
{%- endif %}
{%- if role == 'assistant' %}
{{- role_indicators['assistant'] }}
{%- if msg.content %}
{%- if "</think>" in msg.content %}
{%- set content = msg.content.split('</think>')[-1].strip() %}
{%- set reasoning_content = msg.content.split('</think>')[0].strip() %}
{%- if reasoning_content.startswith("<think>") %}
{%- set reasoning_content = reasoning_content[7:].strip() %}
{%- endif %}
{%- else %}
{%- set content = msg.content %}
{%- endif %}
{%- if msg.reasoning_content %}
{%- set reasoning_content = msg.reasoning_content %}
{%- endif %}
{%- if (not skip_think and loop.last) and reasoning_content is defined %}
{{- "<think>\n" }}
{{- reasoning_content}}
{{- "\n</think>\n\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{{- content }}
{%- endif %}
{%- if msg.tool_calls %}
{%- if msg.content %}
{{- "\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{%- for tool_call in msg.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- if tool_call.arguments is defined %}
{%- set arguments = tool_call.arguments %}
{%- elif tool_call.parameters is defined %}
{%- set arguments = tool_call.parameters %}
{%- else %}
{{- raise_exception('arguments or parameters are mandatory: ' ~ tool_call) }}
{%- endif %}
{{- "<tool_call>" }}{"name": "{{- tool_call.name }}", "arguments": {{ arguments | tojson(ensure_ascii=False) | safe }}}{{- "</tool_call>" }}
{%- if not loop.last %}
{{- "\n" }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- end_of_turn -}}
{%- elif role == "tool" %}
{%- if i == 0 or messages[i - 1].role != "tool" %}
{{- role_indicators['tool'] }}
{%- endif %}
{%- if msg.content is defined %}
{{- "<tool_result>" }}{"result": {{ msg.content | tojson(ensure_ascii=False) | safe }}}{{- "</tool_result>" }}
{%- endif %}
{%- if loop.last or messages[i + 1].role != "tool" %}
{{- end_of_turn -}}
{%- else %}
{{- "\n" }}
{%- endif %}
{%- else %}
{{- role_indicators[role] }}
{{- msg.content }}
{{- end_of_turn -}}
{%- endif %}
{% endfor %}
{%- if add_generation_prompt %}
{{- role_indicators['assistant'] }}
{%- if enable_thinking is defined and enable_thinking is true %}
{{- "<think>\n" }}
{%- else %}
{{- "<think>\n\n</think>\n\n" }}
{%- endif %}
{%- endif %}

View File

@@ -15,6 +15,12 @@
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{#- Determine the real last index: use provided value or default to messages length - 1 #}
{%- if real_last_index is defined and real_last_index is not none %}
{%- set ns.real_last_index = real_last_index %}
{%- else %}
{%- set ns.real_last_index = messages|length - 1 %}
{%- endif %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
@@ -37,7 +43,7 @@
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}

View File

@@ -203,6 +203,7 @@ def wrap_streaming_dataset(
max_seq_length=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
multipack_attn=multipack_attn,
bin_size=cfg.sample_packing_bin_size,
)
# Set this to 1 so downstream data_loader doesn't try to increase the batch size
@@ -254,6 +255,7 @@ def encode_packed_streaming(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
bin_size: int,
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = True,
@@ -278,6 +280,7 @@ def encode_packed_streaming(
batch_max_len=batch_size * max_seq_length,
drop_last=True,
num_processes=1,
bin_size=bin_size,
)
chunked_data = defaultdict(list)

View File

@@ -180,15 +180,20 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
dataset: Dataset to filter.
sequence_len: Maximum length for sequences to keep
cfg: Dictionary mapping `axolotl` config keys to values.
"""
Remove or truncate sequences that exceed the configured maximum length from a dataset.
Parameters:
dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged.
sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated.
cfg (DictDefault): Configuration object with keys:
- excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences.
- min_sample_len: minimum allowed sequence length (used when truncating or dropping).
- dataset_num_proc: number of processes to use for non-streaming datasets.
- is_preprocess: when true, bypasses cached preprocessing during filtering.
Returns:
Filtered dataset with long sequences removed.
Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`.
"""
if (
hasattr(dataset, "column_names")
@@ -206,10 +211,13 @@ def handle_long_seq_in_dataset(
)
return dataset
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
drop_long = functools.partial(
drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
raise_on_drop=excess_length_strategy == "raise",
)
with contextlib.suppress(AttributeError):
@@ -230,7 +238,6 @@ def handle_long_seq_in_dataset(
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
@@ -259,4 +266,4 @@ def handle_long_seq_in_dataset(
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
return dataset
return dataset

View File

@@ -30,6 +30,7 @@ class Mistral3Processor(ProcessorMixin):
Wraps HFMistralTokenizer and adds image processing capabilities.
"""
# TODO(nano): This should be removed in transformers V5
attributes = ["tokenizer"]
tokenizer_class = "HFMistralTokenizer"

View File

@@ -80,6 +80,9 @@ class HFMistralTokenizer(MistralCommonTokenizer):
) -> str | list[int]:
"""Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg"""
# pop unnecessary kwarg for mistral
kwargs.pop("real_last_index", None)
try:
if add_generation_prompt:
self._set_mode(ValidationMode.serving)
@@ -218,3 +221,10 @@ class HFMistralTokenizer(MistralCommonTokenizer):
model_input_names=model_input_names,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]:
"""
Patches to remove save_jinja_files from being passed onwards.
"""
kwargs.pop("save_jinja_files", None)
return super().save_pretrained(*args, **kwargs)

View File

@@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler):
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
bin_size: int, # The max number of samples that can be packed in a single bin
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 4, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
mp_start_method: str = "fork",
@@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler):
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
bin_size=self.bin_size or self.batch_max_len,
num_processes=min(4, num_processes) if num_processes else 4,
safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method,

View File

@@ -23,6 +23,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
@@ -141,6 +142,13 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "Reward modelling: `True` or `False`"},
)
dynamic_checkpoint: DynamicCheckpointConfig | None = Field(
default=None,
json_schema_extra={
"description": "Configuration for dynamic checkpointing (trigger by file or signal). "
"Set 'enabled: true' to activate this feature."
},
)
process_reward_model: bool | None = Field(
default=None,
json_schema_extra={
@@ -1061,7 +1069,7 @@ class AxolotlInputConfig(
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate GPU capabilities with the configured options"""
"""Wrapper to valdiate GPU capabilities with the configured options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities

View File

@@ -0,0 +1,31 @@
"""Schema for dynamic checkpoint configuration."""
from pydantic import BaseModel, Field
class DynamicCheckpointConfig(BaseModel):
"""Configuration for dynamic checkpoint triggering during training."""
enabled: bool = Field(
default=False,
json_schema_extra={
"description": "Enable dynamic checkpoint triggering during training. "
"Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger. "
},
)
check_interval: int = Field(
default=10,
ge=1,
json_schema_extra={
"description": "Check for trigger file every N steps (reduces I/O overhead). "
"Default: 100"
},
)
trigger_file_path: str = Field(
default="",
json_schema_extra={
"description": "Custom trigger filename (optional). "
"If not specified, defaults to 'axolotl_checkpoint.save'. "
"Specify a filename (not a full path) to override the default."
},
)

View File

@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
falcon_h1 = "falcon_h1"
tokenizer_default = "tokenizer_default"
exaone = "exaone"
exaone4 = "exaone4"
metharme = "metharme"
pixtral = "pixtral"
llava = "llava"

View File

@@ -100,6 +100,21 @@ class LoraConfig(BaseModel):
)
},
)
peft_ensure_weight_tying: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Whether to tie adapter weights for tied model weights. "
"See https://github.com/huggingface/peft/issues/2864"
)
},
)
peft_autocast_adapter_dtype: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT."
},
)
qlora_sharded_model_loading: bool | None = Field(
default=False,

View File

@@ -173,3 +173,9 @@ class TRLConfig(BaseModel):
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)
rollout_func: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to custom rollout function. Must be importable from current dir."
},
)

View File

@@ -201,16 +201,33 @@ def add_pose_position_ids(
def add_length(sample):
"""
Set the "length" field on a sample to the number of input tokens.
Parameters:
sample (Mapping-like): A sample containing an "input_ids" sequence.
Returns:
sample (dict-like): The same sample with "length" set to len(sample["input_ids"]).
"""
sample["length"] = len(sample["input_ids"])
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
Return whether a sample (single or batched) should be kept based on sequence length constraints.
Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError.
Parameters:
sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences.
sequence_len (int): Maximum allowed sequence length (inclusive).
min_sequence_len (int): Minimum allowed sequence length (inclusive).
raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len.
Returns:
bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept.
"""
min_sequence_len = min_sequence_len or 2
@@ -225,12 +242,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int)
length = len(input_ids)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
return min_sequence_len <= length <= sequence_len
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
length = len(seq)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
results.append(min_sequence_len <= length <= sequence_len)
return results
@@ -715,4 +740,4 @@ def setup_trainer(
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
return trainer_builder.build(total_num_steps)
return trainer_builder.build(total_num_steps)

View File

@@ -1,6 +1,4 @@
"""
shared pytest fixtures
"""
"""Shared pytest fixtures"""
import functools
import importlib
@@ -582,3 +580,9 @@ def test_load_fixtures(
download_llama2_model_fixture,
):
pass
@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
yield

View File

@@ -396,10 +396,10 @@ def rand_reward_func(prompts, completions) -> list[float]:
),
("orpo_cfg", None), # don't use fixture for orpo to use smaller split
("kto_cfg", None), # no fixture for kto
(
"simpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
# (
# "simpo_cfg",
# "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
# ),
],
)
def test_custom_optimizer_cls_and_kwargs(

View File

@@ -2,6 +2,8 @@
Simple end-to-end test for Liger integration
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -62,7 +64,11 @@ class LigerIntegrationTestCase:
check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
@pytest.mark.parametrize(
"liger_use_token_scaling",
[True, False],
)
def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -74,6 +80,7 @@ class LigerIntegrationTestCase:
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"liger_use_token_scaling": liger_use_token_scaling,
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {

Some files were not shown because too many files have changed in this diff Show More