Compare commits


104 Commits

Author SHA1 Message Date
Dan Saunders
8564961423 fix compile 2025-09-19 13:59:57 -04:00
Dan Saunders
ce21da9177 fix compile 2025-09-19 13:55:54 -04:00
Dan Saunders
b5dc58373f fix compile 2025-09-19 13:52:42 -04:00
Dan Saunders
7327144344 compile 2025-09-19 13:41:12 -04:00
Dan Saunders
fb11f696e9 bench sweep 2025-09-19 13:24:40 -04:00
Dan Saunders
105c817b0b default fix 2025-09-19 16:59:20 +00:00
Dan Saunders
64345e7707 recurse fix 2025-09-19 12:58:58 -04:00
Dan Saunders
0f8b921399 contig 2025-09-19 12:47:53 -04:00
Dan Saunders
336616d659 defaults 2025-09-19 16:45:39 +00:00
Dan Saunders
d2f1e23bcd fix 2025-09-19 12:45:18 -04:00
Dan Saunders
42aadc5069 bench fix 2025-09-19 12:34:08 -04:00
Dan Saunders
1e7302d30a bench fix 2025-09-19 12:20:35 -04:00
Dan Saunders
63544ce709 fix 2025-09-19 11:34:27 -04:00
Dan Saunders
3bfed0aac8 shared expert detection 2025-09-19 11:24:26 -04:00
Dan Saunders
bfc848f81d bits and pieces 2025-09-19 02:12:57 +00:00
Dan Saunders
abe1cad6bc another bench 2025-09-18 13:45:19 -04:00
Dan Saunders
354389caef torchtitan bench 2025-09-18 13:29:20 -04:00
Dan Saunders
efcd032fce yet another refactor 2025-09-18 13:03:28 -04:00
Dan Saunders
7500641601 yet another refactor 2025-09-18 12:47:15 -04:00
Dan Saunders
0295df5bca precompute fuse 2025-09-18 12:10:46 -04:00
Dan Saunders
b39ef54833 combine mult 2025-09-18 12:08:03 -04:00
Dan Saunders
ad4cd39bcd remove contig 2025-09-18 11:55:15 -04:00
Dan Saunders
5c197275ad inplace 2025-09-18 11:51:17 -04:00
Dan Saunders
19c91e3675 refactor 2025-09-18 11:44:21 -04:00
Dan Saunders
2a176e4923 fix 2025-09-18 11:29:33 -04:00
Dan Saunders
7d867de9b2 refactor 2025-09-18 11:23:15 -04:00
Dan Saunders
01b6792c2e refactor 2025-09-18 11:20:08 -04:00
Dan Saunders
bbf1f14ca4 dtype issues 2025-09-17 23:52:18 +00:00
Dan Saunders
c6878beb7d simplify 2025-09-17 19:15:34 -04:00
Dan Saunders
e62979d11d fix 2025-09-17 18:53:07 -04:00
Dan Saunders
d57b9c67c2 log 2025-09-17 18:52:27 -04:00
Dan Saunders
eaaf16aa00 cumulative offsets 2025-09-17 18:45:15 -04:00
Dan Saunders
f3b953e222 fix? 2025-09-17 18:42:10 -04:00
Dan Saunders
7935dc0911 dtype fix 2025-09-17 18:36:22 -04:00
Dan Saunders
d2b49b2670 error msg 2025-09-17 18:29:30 -04:00
Dan Saunders
b5cb345ca4 fix test 2025-09-17 18:24:00 -04:00
Dan Saunders
03d4c2683e fix perf degradation 2025-09-17 18:20:37 -04:00
Dan Saunders
fd87eed501 minify 2025-09-17 16:42:35 -04:00
Dan Saunders
129db67705 fix 2025-09-17 16:24:29 -04:00
Dan Saunders
38b890a36b fix 2025-09-17 16:16:41 -04:00
Dan Saunders
180920c7bf simplify 2025-09-17 19:49:18 +00:00
Dan Saunders
d024048d74 logs + fix 2025-09-17 14:50:49 -04:00
Dan Saunders
98dc945838 fix 2025-09-17 14:42:53 -04:00
Dan Saunders
108600cd69 update config 2025-09-17 14:36:24 -04:00
Dan Saunders
0e9387c395 fix 2025-09-17 14:35:36 -04:00
Dan Saunders
db61e0d4ff fix 2025-09-17 14:26:25 -04:00
Dan Saunders
51e565f60a logs 2025-09-17 14:15:51 -04:00
Dan Saunders
c774dd0409 refactor + fix 2025-09-17 14:01:39 -04:00
Dan Saunders
7289e0cb55 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
8d483c11f7 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
9c1829cf57 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
135b09d1de logs, qwen2 support 2025-09-17 13:44:26 -04:00
Dan Saunders
de4344a56e patch 2025-09-17 13:44:26 -04:00
Dan Saunders
7d572b58d1 just grouped_mm for now 2025-09-17 13:44:26 -04:00
Dan Saunders
773d7e4291 update 2025-09-17 13:44:26 -04:00
Dan Saunders
fef47a5b7c hardening 2025-09-17 13:44:26 -04:00
Dan Saunders
f6ed8ddc01 fix 2025-09-17 13:44:26 -04:00
Dan Saunders
556d6448fe fix 2025-09-17 13:44:26 -04:00
Dan Saunders
5c2229721d diag 2025-09-17 13:44:26 -04:00
Dan Saunders
d7de6b0e96 grouped_mm 2025-09-17 13:44:26 -04:00
Dan Saunders
3c6648678f numerics 2025-09-17 13:44:26 -04:00
Dan Saunders
5b19a1ea9c improve 2025-09-17 13:44:26 -04:00
Dan Saunders
cfefad1eea fix 2025-09-17 13:44:26 -04:00
Dan Saunders
125e7b5fe6 fast path 2025-09-17 13:44:26 -04:00
Dan Saunders
479b6144df tflops 2025-09-17 13:44:26 -04:00
Dan Saunders
68da65cba2 update 2025-09-17 13:44:26 -04:00
Dan Saunders
0d689bb421 cache, example 2025-09-17 13:44:26 -04:00
Dan Saunders
43ada1278a moe kernels init scaffold 2025-09-17 13:44:26 -04:00
Dan Saunders
4065bc14c6 Debug log, logging improvements (#3159)
* simplify logging

* remove comment

* progress on debug.log

* add debug-level logger for file log

* simplify

* case insensitivity; 3rd party logging improvements

* simplify

* fix

* tests

* lint

* nits

* nit

* Update tests/test_utils_tee.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* cleanup / comments

* fix

* oops

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2025-09-17 13:27:03 -04:00
salman
e5c427f6de qat doc updates (#3162) [skip-ci] 2025-09-17 10:38:15 +01:00
Wing Lian
86d6ee7c05 upgrade trl and accelerate (#3161)
* upgrade trl==0.23.0

* upgrade accelerate patch fix

* add hints when using gradient_checkpointing with DPO

* set gradient-checpointing properly
2025-09-16 14:53:01 -04:00
Wing Lian
d4cff1b7bb improve setting of NCCL_P2P_DISABLE on runpod (#3132) [skip ci]
* improve setting of NCCL_P2P_DISABLE on runpod

* use recs from review
2025-09-16 14:52:45 -04:00
Wing Lian
1ef6c196f7 setup env vars for ray train for FSDP (#3130) [skip ci] 2025-09-16 14:52:29 -04:00
salman
58d67bf98d Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107) 2025-09-12 10:55:50 +01:00
salman
0401a15888 SEO go brrr (#3153) [skip-ci] 2025-09-12 10:55:11 +01:00
NanoCode012
fcfc13d710 feat(doc): update thinking and chat_template notes (#3114) [skip ci]
* feat: update thinking and chat_template notes

* fix: grammar
2025-09-12 14:45:18 +07:00
salman
9406c0c488 log before eval step (#3148) [skip-ci] 2025-09-11 11:19:30 +01:00
Dan Saunders
1b53c49e1a text diffusion training plugin (#3067)
* diffusion training plugin

* cleanup

* nits

* fixes + improvements

* add back in reinit_weights (clobbered?); masking / pretrain fixes

* nits

* cleanup; tests draft

* sample generation, tests fixes

* fixes

* nits

* add inference support; add auto-mask token support

* nits

* nits

* progress

* simplify logging

* lint

* prefix args with diffusion_

* coderabbito

* tests fix

* nit

* nits

* cleanup + nits

* nits

* fix SFT sample gen

* fixes

* fix

* comments

* comments

* lint

* reward model lora fix

* cleanup; fix pretraining_dataset case

* gradio inference

* update cfgs

* update cfgs

* train, generation parity, cleanup

* fix

* simplify

* test

* test fix
2025-09-10 20:27:00 -04:00
NanoCode012
b71482cec5 Feat: add hunyuan v1 (#3016)
* feat: add hunyuan cce support

* feat: update cce docs

* feat: add multipack support for granite and hunyuan

* feat: add hunyuan docs and example config

* feat: update readme instructions to include CCE installation

* fix: chat template log appearing despite tokenizer already having template

* feat: add vram usage

* fix: remove duplicate cce install

* fix: use latest commit of PR in case rebased/pushed

* Revert "fix: use latest commit of PR in case rebased/pushed"

This reverts commit 8b60aa00de.

* feat: update doc as upstream merged
2025-09-10 09:03:30 +07:00
NanoCode012
79103b01ca Feat: add seedoss (#3104) [skip ci]
* feat: add seedoss cce

* feat: add seedoss config and docs

* fix: shouldn't have target modules with target linear

* feat: add vram numbers

* fix: hf link

* fix: name

* fix: support multipack seedoss

* fix: merge error

* feat: update seedoss instructions for transformers release
2025-09-10 09:01:02 +07:00
salman
9640338d37 Default include_tkps to true (#3134)
* default true

* force e2e

* causal trainer only

* fix eval loggin [skip-ci]

* revert setup.py

* force tests

* guarding

* guarding

* fix test case

* use evaluate [skip-e2e]

* use evaluate [skip-e2e]

* kick off ci

* fixing

* reverting
2025-09-09 10:50:21 -04:00
Wing Lian
b5d4c7ff54 allow 1% deviation for codecov (#3138) [skip ci] 2025-09-07 11:01:03 -04:00
Seungduk Kim
8fd9221f13 Add ipo as an rl type that shares DPODataset config (#3128)
* Add `ipo` as an `rl` type that shares DPODataset config

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-09-07 10:49:10 -04:00
github-actions[bot]
bf00f29f3a chore: update pre-commit hooks (#3137) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-09-07 10:33:20 -04:00
NanoCode012
1d32278755 feat: upgrade transformers to v4.56.1 (#3127)
* feat: upgrade transformers to v4.56

* fix handling of CP/SP now that position_ids are default even for unpacked sequences

* feat: monkeypatch list_repo_templates

* fix: apply patch for tests only

* see if updated main works at least

* fix: update to patch release and remove monkeypatch

* remove fsdp2 eval patch

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-09-05 11:00:54 -04:00
NanoCode012
c6ae5c43cb fix: chat template jinja file not being loaded during inference (#3112)
* fix: chat template jinja file not being loaded during inference

* fix: bot comment
2025-09-03 16:25:09 -04:00
yardenhoch
efa1da52d5 Center rewards coefficient (#3124)
* feat: add center_rewards_coefficient for reward modeling

- Add center_rewards_coefficient parameter to Pydantic schema with paper reference
- Pass parameter through base builder and causal builder to training args
- Add documentation section with usage examples and theoretical background
- Enable parameter in reward modeling example configs with recommended value
- Enables reward centering for improved training stability in RLHF workflows

Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244)
to incentivize mean-zero reward outputs without post-training normalization.

* Update description

* test: add unit tests for center_rewards_coefficient integration

* Update src/axolotl/core/builders/base.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* reference to TRL documentation.

* add new reward model configuration for qwen3 with comprehensive parameters

* Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments.

* Refactor reward modeling documentation to consolidate information on center_rewards_coefficient

* Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup.

* linting

* nit

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-09-03 16:22:37 -04:00
mhenrichsen
48db520d92 Create 270m-qlora.yml (#3075) [skip ci]
Adds 270m gemma3 qlora
2025-09-03 16:20:32 -04:00
NanoCode012
53a0c1f39c feat: add peft_trainable_token_indices (#3062)
* feat: add peft_trainable_token_indices

* feat: add warning compat with fix_untrained_tokens
2025-09-03 01:48:01 -04:00
github-actions[bot]
4cc6038d52 chore: update pre-commit hooks (#3122) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-09-03 01:41:34 -04:00
NanoCode012
e48aa8a5b1 feat(doc): improve visibility for colab notebooks (#3110) [skip ci]
* feat: improve visibility for colab notebooks

* fix: link to GH colab

* feat: change to badge and move higher
2025-09-03 01:40:53 -04:00
xuyifann
24aba5caca Clamping the len of dataloader to minimum of 1 (#3100) [skip ci]
* Clamping the len of dataloader to minimum of 1

* linter reformat
2025-09-03 01:40:27 -04:00
Wing Lian
06bebcb65f run cu128-2.8.0 e2e tests on B200 (#3126)
* run cu128-2.8.0 e2e tests on B200

* not an int 🤦

* fix yaml
2025-09-02 13:13:23 -04:00
Dan Saunders
231a67e70b Streaming SFT support (#3101)
* working

* fixes

* deprecate --iterable; cleanup

* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size

* improvements

* tests

* remove unused

* docs, examples

* nit

* nit

* add val_set_size validation

* val

* nit

* min

* coderabbito

* cleanup

* nit

* add depr warning, cleanup

* nit

* fix test, fix quarto

* fix

* review comments

* review comments

* fix
2025-09-02 12:08:44 -04:00
Wing Lian
0094a2d744 support for tiledmlp for GPT-OSS (#3116)
* fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS

* add logging back

* update deps
2025-08-29 13:52:49 -04:00
Wing Lian
7ed40f1d70 automatically set env vars for single gpu deepspeed zero3 (#3118) [skip ci]
* automatically set env vars for single gpu deepspeed zero3

* use setdefault
2025-08-29 13:36:47 -04:00
VED
5b6ec2820f patch for ds_grads_remaining in deepspeed (#3102) [skip ci]
* patch deepspeed

* deepspeed patch for ds_grads_remaining

* patch in Patchmanager

* chore: lint

* deepseed utils

* chore2

* patch ds_grads_remaining chore

* chore lint

* chore lint

* remove torch.nn patch

* lint

* Update src/axolotl/monkeypatch/utils.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patched with checkpointwarapper

* lint

* only apply deepspeed patch when using activation offloading

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-29 12:12:09 -04:00
Wing Lian
6afba3871d Add support for PyTorch 2.8.0 (#3106)
* Add support for PyTorch 2.8.0

* loosen triton requirements

* handle torch 2.8.0 in setup.py

* fix versions

* no vllm for torch 2.8.0

* remove comment

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-28 09:10:40 -04:00
Dan Saunders
dc338c3b0e Update .coderabbit.yaml (#3109) [skip ci]
Oops, should be false.
2025-08-27 09:50:52 -04:00
salman
d0d2fc5606 Tokens per second logging [skip-e2e] (#3072) 2025-08-27 09:10:14 +01:00
Wing Lian
e1131e9619 make always skip_move_to_device default as true (#3084) 2025-08-26 09:30:22 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Wing Lian
0e9945e3b9 deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli

* cleanup
2025-08-26 09:29:50 -04:00
NanoCode012
0de254a0d0 feat: add gemma3_text attention handling for lora kernels (#3103) 2025-08-26 16:47:26 +07:00
143 changed files with 14498 additions and 1598 deletions

View File

@@ -12,6 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
auto_incremental_review: false
chat:
auto_reply: true

View File

@@ -36,6 +36,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -110,6 +115,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -169,6 +179,12 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -33,13 +33,6 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -240,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -298,6 +298,13 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras: fbgemm-gpu
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -318,6 +325,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
@@ -334,10 +342,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:

3
.gitignore vendored
View File

@@ -190,3 +190,6 @@ out/
# vim
*.swp
# scm auto-versioning
src/axolotl/_version.py

View File

@@ -11,7 +11,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.9
rev: v0.12.12
hooks:
- id: ruff
args: [--fix]

View File

@@ -1,6 +1,6 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
title: "Axolotl: Open Source LLM Post-Training"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"

View File

@@ -5,6 +5,9 @@
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<p align="center">
<strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
</p>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
@@ -17,6 +20,7 @@
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -49,20 +53,21 @@
## ✨ Overview
Axolotl is a tool designed to streamline post-training for various AI models.
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Features:
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
## 🚀 Quick Start
## 🚀 Quick Start - LLM Fine-tuning in Minutes
**Requirements**:
@@ -70,6 +75,10 @@ Features:
- Python 3.11
- PyTorch ≥2.6.0
### Google Colab
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
### Installation
#### Using pip
@@ -155,7 +164,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Post-Training for AI Models},
title = {Axolotl: Open Source LLM Post-Training},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},

View File

@@ -153,7 +153,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.streaming
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -272,6 +272,7 @@ website:
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
@@ -284,6 +285,7 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/moe_backends.md
- docs/nd_parallelism.qmd
- section: "Troubleshooting"

View File

@@ -57,7 +57,8 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = f"L40S:{N_GPUS}"
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):

View File

@@ -12,7 +12,7 @@ coverage:
default:
# basic
target: auto
threshold: 0%
threshold: 1%
base: auto
# advanced
branches: null
@@ -27,7 +27,7 @@ coverage:
default:
# basic
target: auto
threshold: 0%
threshold: 1%
base: auto
# advanced
branches: null

View File

@@ -134,7 +134,7 @@ For providers supporting Docker:
### Google Colab {#sec-colab}
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
## Platform-Specific Instructions {#sec-platform-specific}

18
docs/moe_backends.md Normal file
View File

@@ -0,0 +1,18 @@
# MoE Backends in Axolotl

Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):

- Set `moe_backend: auto|torch_grouped|naive`

## Behavior

- `auto` (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
- `torch_grouped`: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- `naive`: keeps the reference per-expert loop.

## Notes

- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
- No changes to training scripts are required; selection happens inside the model forward.

## Example

```yaml
moe_backend: torch_grouped
```

```bash
accelerate launch -m axolotl.cli.train path/to/config.yaml
```
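For context, a fuller config fragment selecting the grouped backend might look like the sketch below. Only `moe_backend` is specific to this feature; the surrounding keys are ordinary Axolotl options borrowed from the example configs elsewhere in this change set.

```yaml
# Minimal sketch: MoE fine-tuning config using the grouped backend
base_model: Qwen/Qwen1.5-MoE-A2.7B
moe_backend: torch_grouped   # auto | torch_grouped | naive
sequence_len: 512
micro_batch_size: 1
gradient_accumulation_steps: 8
bf16: auto
flash_attention: true
```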

View File

@@ -63,15 +63,6 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
::: {.callout-tip}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}

View File

@@ -23,10 +23,17 @@ To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
We support the following quantization schemas:
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight`
- `Float8DynamicActivationInt4Weight`
- `NVFP4`
Once you have finished training, you must quantize your model using the same quantization configuration that you used for training. You can use the [`quantize`](./quantize.qmd) command to do this.
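As a concrete illustration, an NVFP4 QAT block (mirroring the new NVFP4 example config added in this change set) might look like the following sketch; the `fake_quant_after_n_steps` value is purely illustrative.

```yaml
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16                 # only group_size of 16 is supported with nvfp4
  fake_quant_after_n_steps: 100  # optional; illustrative value
```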

View File

@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
@@ -39,9 +39,8 @@ you used to train the model:
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
weight_dtype: int4
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
@@ -51,3 +50,11 @@ axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
::: {.callout-note}
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
:::
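To make the naming behavior concrete, a sketch of a QAT config that also pushes to the Hub might look like this; the repo name is the placeholder from the callout above, not a real repository you need to use.

```yaml
# qat.yml (sketch)
hub_model_id: axolotl-ai-cloud/qat-nvfp4-llama3B   # quantized model is pushed as ...-llama3B-nvfp4w
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16
output_dir: ./outputs/qat_out/   # directory containing the final training checkpoint
```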

View File

@@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data that contains preference annotations for an entire interaction between the user and model (rather than per-turn or per-step).
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b

120
docs/streaming.qmd Normal file
View File

@@ -0,0 +1,120 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
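Putting these options together, a minimal sketch (using the same values as the examples later in this page) looks like:

```yaml
streaming: true
streaming_multipack_buffer_size: 10000   # samples buffered before packing
shuffle_merged_datasets: true            # shuffle within the streaming buffer
```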
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- You can control the memory usage by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed matters when streaming from remote sources, and disk read speed matters when streaming a local dataset
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -0,0 +1,10 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
]
},
{
@@ -176,8 +176,8 @@
}
],
"source": [
"from axolotl.utils.dict import DictDefault\n",
"from axolotl.cli.config import load_cfg\n",
"from axolotl.utils.dict import DictDefault\n",
"\n",
"# Axolotl provides full control and transparency over model and training configuration\n",
"config = DictDefault(\n",
@@ -251,10 +251,10 @@
},
"outputs": [],
"source": [
"from axolotl.utils import patch_optimized_env\n",
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
"\n",
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"patch_optimized_env()"
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"set_pytorch_cuda_alloc_conf()"
]
},
{

View File

@@ -20,7 +20,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run the finetuning example:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
# huggingface repo
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -106,6 +106,16 @@ See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-to
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### Thinking and chat_template masking conflict
OpenAI's Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl's `chat_template` masking.
If your dataset has `thinking` content mid-turn, there are two paths we recommend:
- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message).
- Adjust your dataset to only have `thinking` content in the last turn.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -0,0 +1,85 @@
# Finetune HunYuan with Axolotl
Tencent released a family of open-source models called HunYuan at 0.5B, 1.8B, 4B, and 7B parameter scales, in both Pre-trained and Instruct variants. The models can be found on [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune them with Axolotl using multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from `main` (HunYuan support is only available on nightly) or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml
```
This config uses about 4.7 GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Dataset
HunYuan Instruct models can operate in either a slow-think or a fast-think pattern. For best results when fine-tuning the Instruct models, your dataset should be adjusted to match the corresponding pattern.
```python
# fast think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\n\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
# slow think pattern
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "/no_think What color is the sun?" },
{"role": "assistant", "content": "<think>\nThe user is asking about the color of the sun. I need to ...\n</think>\n<answer>\nThe sun is yellow.\n</answer>"}
]
```
### TIPS
- For inference, the official Tencent team recommends
```json
{
"do_sample": true,
"top_k": 20,
"top_p": 0.8,
"repetition_penalty": 1.05,
"temperature": 0.7
}
```
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Tencent HunYuan Blog](https://hunyuan.tencent.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,64 @@
base_model: tencent/Hunyuan-0.5B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,20 +15,18 @@ liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
sample_packing: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
sample_packing: false
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: int8
@@ -67,7 +65,7 @@ fsdp:
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
@@ -76,6 +74,6 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|end_of_text|>
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,64 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
gradient_accumulation_steps: 1
micro_batch_size: 64
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,56 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
pretraining_dataset:
- path: wikitext
name: wikitext-103-raw-v1
type: completion
field: text
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
diffusion:
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
num_diffusion_steps: 128
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000
warmup_ratio: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4
sdp_attention: true
bf16: auto
tf32: true
logging_steps: 1
save_strategy: steps
save_steps: 1000
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,59 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.05
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
diffusion:
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 250
output_dir: ./outputs/model-out
sequence_len: 512
sample_packing: true
eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
warmup_steps: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5
bf16: auto
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
sdp_attention: true
logging_steps: 1
save_strategy: best
eval_strategy: epoch
special_tokens:
pad_token: "<|end_of_text|>"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -18,7 +18,13 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run the finetuning example:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml

View File

@@ -0,0 +1,53 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
# Keep VRAM low
load_in_8bit: false
load_in_4bit: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/qwen2-moe-qlora-10gb
# Train small to fit 10GB
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 5
flash_attention: true
warmup_ratio: 0.03
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
model_config:
output_router_logits: true
special_tokens:

View File

@@ -0,0 +1,44 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -0,0 +1,54 @@
# Finetune ByteDance's Seed-OSS with Axolotl
[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) is a series of 36B-parameter open-source models trained by ByteDance's Seed Team.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from `main` (Seed-OSS support is only available on nightly) or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml
```
This config uses about 27.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [ByteDance Seed Website](https://seed.bytedance.com/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: ByteDance-Seed/Seed-OSS-36B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,50 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimizes per-step padding tokens
## Performance Tips
- Download small / frequently-used datasets locally for better performance
- Larger buffer sizes improve packing efficiency

View File

@@ -0,0 +1,57 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
name: sample-10BT
type: pretrain
text_column: text
split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,55 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -22,6 +22,9 @@ pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# audio
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:

View File

@@ -32,7 +32,7 @@ line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
select = ["E", "F", "W", "C90", "B", "I"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long

View File

@@ -2,8 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -14,12 +13,12 @@ packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.55.3
transformers==4.56.1
tokenizers>=0.21.1
accelerate==1.10.0
accelerate==1.10.1
datasets==4.0.0
deepspeed>=0.17.0
trl==0.21.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trackio
@@ -65,7 +64,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.12.0
torchao==0.13.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6

209
scripts/bench_moe.py Normal file
View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
from __future__ import annotations
import argparse
import sys
import time
import weakref
from pathlib import Path
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for _ in range(warmup):
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
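# Two up projections (gate, up) plus one down projection per routed token, at two FLOPs
# per multiply-accumulate, gives the factor of six; shared-expert work is not counted.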
return 6.0 * tokens * top_k * hidden * inter
def load_hf_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=4)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Torch.compile both paths before benchmarking",
)
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = load_hf_block(
args.hidden,
args.inter,
args.experts,
args.top_k,
device=device,
dtype=dtype,
)
tokens = args.bsz * args.seq
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
# Optional torch.compile
run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
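# Attach a weakref to the parent block; the grouped path presumably uses it to reach
# the router/shared-expert modules without creating a reference cycle.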
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available():
return torch.empty(0)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
with torch.no_grad():
y_ref = run_naive()
if tg is None or not tg.available():
print("torch_grouped\tN/A (unavailable)")
return
y_grouped = run_grouped()
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
return
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
print(
"torch_grouped_check: "
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
)
if args.profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_naive()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_grouped()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__":
main()

311
scripts/bench_moe_sweep.py Normal file
View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
from __future__ import annotations
import argparse
import csv
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def _parse_list(arg: str) -> List[int]:
return [int(v) for v in arg.split(",") if v]
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
run()
if device.type == "cuda":
torch.cuda.synchronize()
times: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _load_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
@dataclass
class Result:
bsz: int
seq: int
hidden: int
inter: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def main() -> None:
p = argparse.ArgumentParser(description="Grouped MoE sweep")
p.add_argument("--batch-sizes", default="4,8,16")
p.add_argument("--seq-lens", default="512,1024,2048")
p.add_argument("--hidden", default="2048,4096")
p.add_argument("--inter", default="5632,8192,14336")
p.add_argument("--experts", default="8,16,32")
p.add_argument("--top-k", default="1,2,4")
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
if tg is None or not tg.available():
print("torch_grouped unavailable; sweep aborted")
return
bs_list = _parse_list(args.batch_sizes)
seq_list = _parse_list(args.seq_lens)
hidden_list = _parse_list(args.hidden)
inter_list = _parse_list(args.inter)
expert_list = _parse_list(args.experts)
topk_list = _parse_list(args.top_k)
results: List[Result] = []
print(
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
for bsz in bs_list:
for seq in seq_list:
tokens = bsz * seq
for hidden in hidden_list:
for inter in inter_list:
for experts in expert_list:
for top_k in topk_list:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = _load_block(
hidden,
inter,
experts,
top_k,
device=device,
dtype=dtype,
)
x = torch.randn(
bsz, seq, hidden, device=device, dtype=dtype
)
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(
block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
data,
block.gate,
block.experts,
block.top_k,
)
return y
naive_ms = _bench(
run_naive,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_naive = run_naive()
grouped_ms = _bench(
run_grouped,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
res = Result(
bsz,
seq,
hidden,
inter,
experts,
top_k,
args.dtype,
naive_ms,
grouped_ms,
naive_ms / grouped_ms,
_estimate_flops(tokens, hidden, inter, top_k)
/ ((naive_ms / 1000.0) * 1e12),
_estimate_flops(tokens, hidden, inter, top_k)
/ ((grouped_ms / 1000.0) * 1e12),
diff.max().item(),
diff.mean().item(),
(
(
diff.pow(2).sum()
/ (y_naive.float().pow(2).sum() + 1e-12)
)
.sqrt()
.item()
),
)
results.append(res)
print(
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
if args.csv:
fieldnames = [
"bsz",
"seq",
"hidden",
"inter",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with args.csv.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"bsz": r.bsz,
"seq": r.seq,
"hidden": r.hidden,
"inter": r.inter,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
if __name__ == "__main__":
import weakref
main()

View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
import torch
# Ensure torchtitan is importable when running from the axolotl tree
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=8)
p.add_argument("--top_k", type=int, default=2)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument(
"--score-before",
action="store_true",
help="Apply routing scores before expert computation (default: after)",
)
p.add_argument(
"--score-func",
choices=["softmax", "sigmoid"],
default="softmax",
)
p.add_argument(
"--route-norm",
action="store_true",
help="Enable Torchtitan router normalization when using sigmoid scores.",
)
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
# Two up projections + one down projection per expert/token combination, at two FLOPs
# per multiply-accumulate, gives the factor of six.
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(
moe: MoE,
*,
device: torch.device,
dtype: torch.dtype,
) -> MoE:
moe = moe.to(device=device)
for param in moe.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
buffers = dict(moe.named_buffers())
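# Router bookkeeping buffers stay in fp32; all other buffers are cast to the benchmark dtype.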
for name, buf in buffers.items():
if name == "tokens_per_expert":
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
moe._buffers[name] = buf.to(device=device, dtype=dtype)
moe.eval()
return moe
@torch.inference_mode()
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for _ in range(warmup):
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def main() -> None:
args = _parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = _map_dtype(args.dtype)
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=True,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
moe_grouped.init_weights(args.init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=False,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
tokens = args.bsz * args.seq
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
)
def run_naive():
return _forward_fn(moe_naive, x)
def run_grouped():
return _forward_fn(moe_grouped, x)
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
f"{tflops_naive:.2f} TFLOP/s"
)
y_naive = run_naive()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
print(
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
print(
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,328 @@
#!/usr/bin/env python
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
from __future__ import annotations
import argparse
import csv
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List
import torch
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_int_list(value: str) -> List[int]:
return [int(v) for v in value.split(",") if v]
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument("--score-before", action="store_true")
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
p.add_argument("--route-norm", action="store_true")
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
module = module.to(device=device)
for param in module.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
for name, buf in module.named_buffers():
if name == "tokens_per_expert":
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
module._buffers[name] = buf.to(device=device, dtype=dtype)
module.eval()
return module
@torch.inference_mode()
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings.append((time.perf_counter() - start) * 1000.0)
return sum(timings) / len(timings)
@dataclass
class SweepResult:
bsz: int
seq: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def _run_case(
*,
bsz: int,
seq: int,
experts: int,
top_k: int,
hidden: int,
inter: int,
dtype: torch.dtype,
device: torch.device,
iters: int,
warmup: int,
init_std: float,
score_before: bool,
score_func: str,
route_norm: bool,
) -> SweepResult:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=True,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
moe_grouped.init_weights(init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=False,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
def run_naive():
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
return _forward(moe_naive, x)
def run_grouped():
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
return _forward(moe_grouped, x)
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
y_naive = run_naive()
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
return SweepResult(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
dtype=str(dtype),
naive_ms=naive_ms,
grouped_ms=grouped_ms,
speedup=speedup,
naive_tflops=naive_tflops,
grouped_tflops=grouped_tflops,
max_abs=max_abs,
mean_abs=mean_abs,
rel_l2=rel_l2,
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
def _print_result(res: SweepResult) -> None:
print(
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
fieldnames = [
"batch_size",
"seq_len",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"batch_size": r.bsz,
"seq_len": r.seq,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
def main() -> None:
args = _parse_args()
dtype = _map_dtype(args.dtype)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_sizes = _parse_int_list(args.batch_sizes)
seq_lens = _parse_int_list(args.seq_lens)
experts_list = _parse_int_list(args.experts)
top_ks = _parse_int_list(args.top_ks)
results: List[SweepResult] = []
_print_header(args.hidden, args.inter, dtype, device)
for bsz in batch_sizes:
for seq in seq_lens:
for experts in experts_list:
for top_k in top_ks:
try:
res = _run_case(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
hidden=args.hidden,
inter=args.inter,
dtype=dtype,
device=device,
iters=args.iters,
warmup=args.warmup,
init_std=args.init_std,
score_before=args.score_before,
score_func=args.score_func,
route_norm=args.route_norm,
)
except RuntimeError as err:
print(
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
file=sys.stderr,
)
continue
results.append(res)
_print_result(res)
if args.csv and results:
_write_csv(args.csv, results)
print(f"Wrote {len(results)} rows to {args.csv}")
if __name__ == "__main__":
main()

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
)

View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
from __future__ import annotations
import sys
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
sys.path.extend(
[
str(ROOT / "transformers" / "src"),
str(ROOT / "src"),
]
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
def main() -> None:
cfg = Qwen2MoeConfig(
hidden_size=4096,
moe_intermediate_size=14336,
shared_expert_intermediate_size=14336,
num_experts=32,
num_experts_per_tok=4,
)
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
experts = block.experts
experts._ax_parent_block = block
impls = _iter_expert_impls(experts)
print(f"impl count: {len(impls)}")
for idx, impl in enumerate(impls[:8]):
has_gate = hasattr(impl, "gate_proj")
has_up = hasattr(impl, "up_proj")
print(
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
)
if has_gate:
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.
Run: python scripts/probe_torch_grouped_ops.py
"""
import sys
def main():
try:
import torch
except Exception as e:
print("Failed to import torch:", e)
sys.exit(1)
print("torch version:", torch.__version__)
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
print("ops namespaces:", namespaces)
found_any = False
for ns in namespaces:
obj = getattr(torch.ops, ns, None)
ops = []
if obj is not None:
try:
ops = dir(obj)
except Exception as e:
print(f"warning: failed to list ops for namespace {ns}: {e}")
cands = [
o
for o in ops
if ("group" in o.lower())
or ("mm_grouped" in o.lower())
or ("matmul_grouped" in o.lower())
or ("grouped" in o.lower())
]
if cands:
found_any = True
print(f"namespace {ns} candidates:", cands)
if not found_any:
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
if __name__ == "__main__":
main()

View File

@@ -64,7 +64,9 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 7):
if (major, minor) >= (2, 8):
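# torch >= 2.8: keep the default xformers requirement unchanged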
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
@@ -125,7 +127,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.2",
"deepspeed==0.17.5",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -160,6 +162,7 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require

View File

@@ -4,5 +4,7 @@ import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging()

View File

@@ -14,9 +14,13 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
default=False,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
"help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
},
)
@@ -111,6 +115,7 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass

View File

@@ -7,6 +7,8 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -38,8 +40,15 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}

View File

@@ -0,0 +1,48 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -0,0 +1,71 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.x for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
)
# 3. Define the Compute Resources for the Training Job
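# NOTE: the accelerator below is hard-coded to H100; the `gpu` value read from cloud.yaml
# is not applied here.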
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -23,7 +23,8 @@ from axolotl.utils.config import (
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
@@ -227,8 +228,11 @@ def load_cfg(
},
)
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
# have to wait for cfg.output_dir to be resolved. We could call this earlier if we write
# to a temporary file, and then move it later.
prepare_debug_log(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
@@ -241,7 +245,6 @@ def load_cfg(
for k, v in cfg.items()
if v is not None
}
LOG.info(
"config:\n%s",
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),

View File

@@ -14,10 +14,12 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
from axolotl.cli.utils.diffusion import (
diffusion_inference,
launch_diffusion_gradio_ui,
)
from axolotl.integrations.base import PluginManager
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -32,6 +34,7 @@ def get_multi_line_input() -> str:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
print("=" * 80)
instruction = ""
for line in sys.stdin:
@@ -46,9 +49,9 @@ def do_inference(
cli_args: InferenceCliArgs,
):
"""
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Runs inference on the command line in a loop. User input is accepted, a chat
template is (optionally) applied, and the model specified in the `axolotl` config is
used to generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -64,17 +67,31 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
print("=" * 80)
print("Commands:")
print(":complete N -> completion mode with N tokens (default 64)")
print(":mask R -> random masking with ratio R (0.01.0)")
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
@@ -104,9 +121,19 @@ def do_inference(
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
print("=" * 80)
model.eval()
with torch.no_grad():
if is_diffusion:
diffusion_inference(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
)
continue
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
@@ -129,7 +156,7 @@ def do_inference(
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print("=" * 80)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -159,10 +186,33 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Detect diffusion mode
plugin_manager = PluginManager.get_instance()
is_diffusion = any(
plugin.__class__.__name__ == "DiffusionPlugin"
for plugin in plugin_manager.plugins.values()
)
if is_diffusion:
launch_diffusion_gradio_ui(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompter_module=prompter_module,
chat_template_str=chat_template_str,
)
return
def generate(instruction):
if not instruction:
return

View File

@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -44,7 +44,7 @@ def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
@cli.command()

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
)
return

View File

@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
LOG = get_logger(__name__)
@@ -43,13 +48,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -57,10 +62,15 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}...")
LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info(
f"Quantizing model with configuration: \n"
@@ -70,11 +80,21 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
@@ -84,5 +104,16 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -17,6 +17,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import prepare_optim_env
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
@@ -59,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -92,6 +92,7 @@ def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
prepare_optim_env(cfg)
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype

View File

@@ -0,0 +1,374 @@
"""Helpers for diffusion-mode inference in CLI and Gradio."""
from __future__ import annotations
import gradio as gr
from colorama import Fore, Style
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
from axolotl.utils.dict import DictDefault
def diffusion_inference(
model,
tokenizer,
cfg,
prompt: str,
chat_template_str: str | None = None,
):
"""Diffusion inference helper method."""
mode = "random"
completion_tokens = 0
target_mask_ratio = None
mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt)
if cleaned:
prompt = cleaned
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=mode,
target_mask_ratio=target_mask_ratio,
completion_tokens=completion_tokens,
)
masked_text = info["masked_text"]
mask_ratio = info["mask_ratio"]
generated_ids = info["generated_ids"]
masked_positions = info["masked_positions"]
orig_ids = info["orig_ids"]
# Display with masked preview and colored diff
if masked_text is not None and mask_ratio is not None:
print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n")
if generated_ids is not None:
# Compute per-token style
styles: list[str] = []
for i, tid in enumerate(generated_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green") # correct fill
elif i < len(orig_ids):
styles.append("red") # incorrect fill
else:
styles.append("normal") # appended
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
# Group contiguous spans by style
styled_spans: list[tuple[str, int, int]] = []
if generated_ids:
current_style = styles[0]
start = 0
for i in range(1, len(generated_ids)):
s = styles[i]
if s != current_style:
styled_spans.append((current_style, start, i))
current_style, start = s, i
styled_spans.append((current_style, start, len(generated_ids)))
out_parts = []
for style_name, a, b in styled_spans:
chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
out_parts.append(chunk_text)
print("Generated:\n" + "".join(out_parts))
else:
print("Generated:\n(no output)")
def _parse_commands(text: str):
"""
Parse leading diffusion commands.
Supported at start of input (can be chained):
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R in [0, 1]
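Example: ":complete 32 Summarize the passage" -> completion mode with 32 new tokens;
the remainder of the line ("Summarize the passage") becomes the prompt.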
"""
tokens = text.strip().split()
i = 0
mode = "random"
completion_tokens = 0
target_mask_ratio = None
consumed = 0
while i < len(tokens) and tokens[i].startswith(":"):
cmd = tokens[i]
i += 1
consumed = i
if cmd == ":complete":
mode = "completion"
if i < len(tokens):
try:
completion_tokens = int(tokens[i])
i += 1
consumed = i
except Exception:
completion_tokens = 64
else:
completion_tokens = 64
elif cmd == ":mask":
mode = "random"
if i < len(tokens):
try:
target_mask_ratio = float(tokens[i])
i += 1
consumed = i
except Exception:
target_mask_ratio = None
else:
i -= 1
consumed = i
break
cleaned = " ".join(tokens[consumed:])
return mode, completion_tokens, target_mask_ratio, cleaned
def run_diffusion(
*,
model,
tokenizer,
cfg: DictDefault,
prompt: str,
chat_template_str: str | None,
mode: str = "random",
target_mask_ratio: float | None = None,
completion_tokens: int = 0,
):
"""Run a single diffusion generation and return a structured result dict."""
if chat_template_str:
batch = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
seq = batch["input_ids"].to(cfg.device)
gen_mode = "completion" if mode == "completion" else "random"
comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0
result = generate(
model,
tokenizer,
original_sequence=seq[:1],
num_diffusion_steps=cfg.diffusion.num_diffusion_steps,
temperature=cfg.diffusion.generation_temperature,
mask_token_id=int(mask_token_id),
mode=gen_mode, # type: ignore[arg-type]
completion_tokens=comp_tokens,
target_mask_ratio=target_mask_ratio,
)
masked_text = result.get("masked") if isinstance(result, dict) else None
mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None
generated_ids = result.get("generated_ids") if isinstance(result, dict) else None
masked_positions = (
set(result.get("masked_positions") or []) if isinstance(result, dict) else set()
)
orig_ids = seq[0].detach().cpu().tolist()
return {
"masked_text": masked_text,
"mask_ratio": mask_ratio,
"generated_ids": generated_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids,
}
def render_html(
*,
generated_ids: list[int] | None,
orig_ids: list[int],
masked_positions: set[int],
tokenizer,
) -> str:
"""Render HTML visualizing diffusion outputs."""
if not generated_ids:
return "<pre>Generated:\n(no output)</pre>"
def _style_for(i: int, tid: int) -> str:
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
return "green"
if i < len(orig_ids):
return "red"
return "normal"
same = i < len(orig_ids) and tid == orig_ids[i]
return "dim" if same else "normal"
# Group contiguous spans by style to reduce HTML size
spans: list[tuple[str, int, int]] = []
if generated_ids:
cur = _style_for(0, generated_ids[0])
start = 0
for i in range(1, len(generated_ids)):
s = _style_for(i, generated_ids[i])
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(generated_ids)))
html_parts = []
for style_name, a, b in spans:
txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False)
if style_name == "green":
html_parts.append(f'<span style="color:#2e7d32">{txt}</span>')
elif style_name == "red":
html_parts.append(f'<span style="color:#c62828">{txt}</span>')
elif style_name == "dim":
html_parts.append(f'<span style="opacity:0.6">{txt}</span>')
else:
html_parts.append(txt)
legend = (
'<div style="font-size:0.9em;margin-bottom:4px">'
'<span style="color:#2e7d32">correct</span>, '
'<span style="color:#c62828">incorrect</span>, '
'<span style="opacity:0.6">unchanged</span>'
"</div>"
)
return (
legend
+ '<pre style="white-space:pre-wrap">Generated:\n'
+ "".join(html_parts)
+ "</pre>"
)
def launch_diffusion_gradio_ui(
*,
model,
tokenizer,
cfg: DictDefault,
prompter_module=None,
chat_template_str: str | None = None,
):
"""Build and launch a simple Gradio UI for diffusion inference."""
with gr.Blocks(
title=cfg.get("gradio_title", "Axolotl Diffusion Interface")
) as demo:
gr.Markdown(
"""
## Axolotl Diffusion Inference
- Mode "Random" masks tokens at a target ratio and fills them.
- Mode "Completion" appends N masked tokens at the end and fills them.
"""
)
with gr.Row():
mode = gr.Radio(
choices=["random", "completion"],
value="random",
label="Mode",
)
mask_ratio = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.4,
label="Mask ratio (random mode)",
interactive=True,
)
completion_tokens = gr.Number(
value=64,
precision=0,
label="Completion tokens (completion mode)",
interactive=True,
visible=False,
)
instruction = gr.Textbox(label="Instruction", lines=6)
run_btn = gr.Button("Generate")
masked_preview = gr.Textbox(label="Masked preview", lines=6)
html_out = gr.HTML(label="Generated")
def _toggle_controls(selected_mode: str):
return (
gr.update(visible=(selected_mode == "random")),
gr.update(visible=(selected_mode == "completion")),
)
mode.change(
_toggle_controls,
inputs=[mode],
outputs=[mask_ratio, completion_tokens],
)
def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int):
if not instruction_text:
return "", "<pre>Generated:\n(no output)</pre>"
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(
instruction=instruction_text.strip("\n")
)
)
else:
prompt = instruction_text.strip()
info = run_diffusion(
model=model,
tokenizer=tokenizer,
cfg=cfg,
prompt=prompt,
chat_template_str=chat_template_str,
mode=selected_mode,
target_mask_ratio=mratio if selected_mode == "random" else None,
completion_tokens=int(ctoks) if selected_mode == "completion" else 0,
)
masked_text = info.get("masked_text")
mask_ratio_val = info.get("mask_ratio")
generated_ids = info.get("generated_ids")
masked_positions = info.get("masked_positions") or set()
orig_ids = info.get("orig_ids") or []
preview = (
f"Masked ({mask_ratio_val:.1%}):\n{masked_text}"
if masked_text is not None and mask_ratio_val is not None
else ""
)
html = render_html(
generated_ids=generated_ids,
orig_ids=orig_ids,
masked_positions=masked_positions,
tokenizer=tokenizer,
)
return preview, html
run_btn.click(
_gen,
inputs=[instruction, mode, mask_ratio, completion_tokens],
outputs=[masked_preview, html_out],
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)

View File

@@ -55,13 +55,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -437,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing:
elif self.cfg.gradient_checkpointing is not None:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
@@ -512,6 +510,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -10,6 +10,7 @@ import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
@@ -35,6 +36,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -74,6 +76,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -340,6 +348,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
if self.cfg.center_rewards_coefficient is not None:
training_arguments_kwargs["center_rewards_coefficient"] = (
self.cfg.center_rewards_coefficient
)
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
@@ -383,10 +395,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters:
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
@@ -404,6 +417,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

View File

@@ -42,12 +42,20 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
"max": torch.max,
"sum": torch.sum,
}
class AxolotlTrainer(
PackingMixin,
@@ -63,6 +71,15 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -78,9 +95,10 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
)
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -327,6 +345,17 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -342,6 +371,11 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
@override
def evaluate(self, *args, **kwargs):
LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}
@@ -526,9 +560,6 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -568,29 +599,61 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
reduction_type = metric_data["reduction"]
fn = REDUCTION_FNS.get(reduction_type)
if fn is None:
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(fn(values).item(), 4)
if is_main_process():
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
def store_metrics(
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
self,
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
train_eval: Literal["train", "eval"] = "train",
reduction: Literal["mean", "min", "max", "sum"] = "mean",
) -> None:
"""
Store metrics with specified reduction type.
Args:
metrics: Dictionary of metric names to values, or metric names to (value,
reduction_type) tuples.
train_eval: Whether this is for training or evaluation.
reduction: Default reduction applied when a metric is given as a bare value rather
than a (value, reduction_type) tuple.
"""
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
if isinstance(value, tuple):
value, _reduction = value # type: ignore[assignment]
else:
value, _reduction = value, reduction
self._stored_metrics[train_eval][key]["values"].append(value)
self._stored_metrics[train_eval][key]["reduction"] = _reduction
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
@@ -657,6 +720,11 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -1,18 +1,17 @@
"""Module containing Dataset functionality"""
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -86,133 +85,3 @@ def wrap_dataset_for_tokenized_prompt(
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__(
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -142,7 +142,7 @@ class BasePlugin:
model: The loaded model.
"""
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
"""Returns a custom class for the trainer.
Args:

View File

@@ -20,8 +20,8 @@ from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
def merge_input_args():

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
```
## Usage
@@ -34,6 +34,7 @@ plugins:
- arcee
- cohere
- cohere2
- deepseek_v3
- gemma
- gemma2
- gemma3
@@ -42,6 +43,7 @@ plugins:
- gemma3n_text
- glm
- glm4
- glm4_moe
- gpt_oss
- granite
- granitemoe
@@ -64,6 +66,7 @@ plugins:
- qwen3
- qwen3_moe
- smollm3
- seed_oss
- voxtral
## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
)

View File

@@ -0,0 +1,154 @@
# Diffusion LM Training Plugin for Axolotl
This plugin enables diffusion language model training within Axolotl, using an
approach inspired by LLaDA (Large Language Diffusion Models).
## Overview
LLaDA is a diffusion-based approach to language model training that uses:
- **Random token masking** during training instead of next-token prediction
- **Bidirectional attention** to allow the model to attend to the full context
- **Importance weighting** based on masking probabilities for stable training
This approach can lead to more robust language models with better understanding of
bidirectional context.
## Installation
The plugin is included with Axolotl. See our
[installation docs](https://docs.axolotl.ai/docs/installation.html).
## Quickstart
Train with an example config (Llama 3.2 1B):
- Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml`
- SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml`
### Basic Configuration
You can also modify your existing configs to enable and customize diffusion training.
Add the following to your Axolotl config:
```yaml
# Enable diffusion LM training plugin
plugins:
- axolotl.integrations.diffusion.DiffusionPlugin
```
Then configure the nested `diffusion` block (defaults shown):
```yaml
diffusion:
noise_schedule: linear # or "cosine"
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
# Mask token (training auto-adds if missing, avoid pad/eos)
mask_token_str: "<|diffusion_mask|>"
# Or use an existing special token id (e.g., 128002 for Llama-3.x)
# mask_token_id: 128002
# Sample generation during training (optional)
generate_samples: true
generation_interval: 100
num_generation_samples: 3
generation_steps: 128
generation_temperature: 0.0
generation_max_length: 100
```
## Supported Models
Any model that supports 4D attention masks should work out of the box. If not, please
create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a
[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)!
## How It Works
### Random Masking
During training, tokens are randomly masked:
- Sample timestep `t` uniformly from [0, 1]
- Calculate masking probability: `p = (1 - eps) * t + eps`
- Randomly mask tokens with probability `p`
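For illustration, a minimal sketch of this forward masking step (assuming a batch of
token ids and a placeholder `mask_token_id`; the actual plugin additionally respects
attention masks, special tokens, and SFT labels):

```python
import torch

def forward_mask(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    """Randomly mask tokens with probability p = (1 - eps) * t + eps."""
    batch_size, seq_len = input_ids.shape
    t = torch.rand(batch_size, device=input_ids.device)  # one timestep per sample
    p_mask = ((1 - eps) * t + eps).unsqueeze(1).expand(-1, seq_len)
    masked = torch.rand_like(p_mask) < p_mask             # which tokens to mask
    noisy = torch.where(masked, torch.full_like(input_ids, mask_token_id), input_ids)
    return noisy, masked, p_mask
```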
### Diffusion Loss
Loss is computed only on masked tokens with (optional) importance weighting:
```python
loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
```
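A slightly fuller sketch of that reduction over the masked positions (assuming
`logits`, `input_ids`, `masked`, and `p_mask` as produced by the masking sketch above;
this mirrors the idea rather than the exact trainer code):

```python
import torch.nn.functional as F

def diffusion_loss(logits, input_ids, masked, p_mask, importance_weighting=True):
    """Cross-entropy on masked tokens only, optionally weighted by 1 / p_mask."""
    token_loss = F.cross_entropy(
        logits[masked].float(), input_ids[masked], reduction="none"
    )
    if importance_weighting:
        token_loss = token_loss / p_mask[masked]
    # Normalize by the total token count for an unbiased estimate of the loss.
    return token_loss.sum() / input_ids.numel()
```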
## Sample Generation
When `diffusion.generate_samples: true`, the plugin generates samples during training:
```
Sample 1:
Original (45 tokens): The quick brown fox jumps over the lazy dog...
Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
Generated: The quick brown fox jumps over the lazy dog...
```
Samples are logged to console and wandb (if enabled).
## Inference
Diffusion inference is integrated into the standard Axolotl CLI. Use the same config
you trained with and run:
```
axolotl inference path/to/your-config.yaml
```
Optionally, pass `--gradio` to use a simple web interface.
Interactive controls (prefix the prompt with commands):
- `:complete N` → completion mode with N new masked tokens appended (default 64)
- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0]
Example session:
```
================================================================================
Commands:
:complete N -> completion mode with N tokens (default 64)
:mask R -> random masking with ratio R (0.0-1.0)
================================================================================
Give me an instruction (Ctrl + D to submit):
:mask 0.4 The quick brown fox jumps over the lazy dog
Masked (40.0%):
The [MASK] brown [MASK] jumps over the [MASK] dog
Generated:
The quick brown fox jumps over the loud dog
```
## Metrics and Monitoring
The plugin adds (or modifies) several metrics to track diffusion training:
- `train/loss`: Weighted diffusion loss
- `train/accuracy`: Accuracy on masked tokens
- `train/mask_ratio`: Average fraction of tokens masked
- `train/num_masked_tokens`: Number of tokens masked
- `train/avg_p_mask`: Average masking probability
- `train/ce_loss`: Unweighted cross-entropy loss
- `train/importance_weight_avg`: Average importance weight
## Limitations
- No flash attention support
- No RL training support
## References
- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
- [Axolotl Documentation](https://docs.axolotl.ai/)
- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args)

View File

@@ -0,0 +1,19 @@
"""Diffusion LM training plugin init."""
from .args import DiffusionArgs, DiffusionConfig
from .callbacks import DiffusionGenerationCallback
from .generation import generate
from .plugin import DiffusionPlugin
from .trainer import DiffusionTrainer
from .utils import create_bidirectional_attention_mask, resolve_mask_token_id
__all__ = [
"DiffusionArgs",
"DiffusionPlugin",
"DiffusionTrainer",
"generate",
"resolve_mask_token_id",
"create_bidirectional_attention_mask",
"DiffusionGenerationCallback",
"DiffusionConfig",
]

View File

@@ -0,0 +1,95 @@
"""Config args for diffusion LM training (nested under `diffusion:`)."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field, model_validator
class DiffusionConfig(BaseModel):
"""Nested diffusion configuration available under the `diffusion` key."""
# Noise schedule config
noise_schedule: Literal["linear", "cosine"] = Field(
default="linear", description="Type of noise schedule for diffusion training"
)
min_mask_ratio: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum masking ratio for diffusion noise schedule",
)
max_mask_ratio: float = Field(
default=0.9,
ge=0.0,
le=1.0,
description="Maximum masking ratio for diffusion noise schedule",
)
num_diffusion_steps: int = Field(
default=128, ge=1, description="Number of diffusion timesteps"
)
eps: float = Field(
default=1e-3,
ge=0.0,
le=1.0,
description="Epsilon value for minimum masking probability in forward process",
)
# Training config
importance_weighting: bool = Field(
default=True,
description="Apply importance weighting to loss based on masking probability",
)
mask_token_id: int | None = Field(
default=None,
description=(
"Token ID to use for masking. Unset by default; can use one of the "
"tokenizer's special tokens here."
),
)
mask_token_str: str | None = Field(
default=None,
description=(
"Token string to use as a mask. If `mask_token_id` is invalid or unset, "
"this token will be ensured to exist as an additional special token and "
"used. If absent, a default '<|diffusion_mask|>' will be added."
),
)
# Sample generation config
generate_samples: bool = Field(
default=True, description="Enable sample generation during training"
)
generation_interval: int = Field(
default=100, ge=1, description="Generate samples every N steps"
)
num_generation_samples: int = Field(
default=3, ge=1, description="Number of samples to generate each time"
)
generation_steps: int = Field(
default=128, ge=1, description="Number of diffusion steps for generation"
)
generation_temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for generation sampling (0.0 = deterministic)",
)
generation_max_length: int = Field(
default=100, ge=1, description="Maximum sequence length for generation"
)
@model_validator(mode="after")
def _validate_mask_ratios(self) -> "DiffusionConfig":
if self.min_mask_ratio > self.max_mask_ratio:
raise ValueError("min_mask_ratio must be ≤ max_mask_ratio")
return self
class DiffusionArgs(BaseModel):
"""Plugin entry that exposes the nested `diffusion` block to the core config."""
diffusion: DiffusionConfig = Field(
default_factory=DiffusionConfig,
description="Diffusion training configuration. Only nested block is supported.",
)

View File

@@ -0,0 +1,174 @@
"""Callbacks for diffusion training."""
import logging
import sys
import wandb
from colorama import Fore, Style
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from .generation import generate_samples
# Simpler logger for more readable sample generation
logger = logging.getLogger(__name__)
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)
class DiffusionGenerationCallback(TrainerCallback):
"""Callback for generating samples during diffusion training."""
def __init__(self, trainer):
self.trainer = trainer
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
if (
state.global_step > 0
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
):
if not self.trainer.state.is_world_process_zero:
return
# Use eval dataloader if available, otherwise use train dataloader
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
except Exception:
dataloader = None
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
# Generate samples
diffusion_cfg = self.trainer.cfg.diffusion
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=diffusion_cfg.num_generation_samples,
max_length=diffusion_cfg.generation_max_length,
num_diffusion_steps=diffusion_cfg.generation_steps,
temperature=diffusion_cfg.generation_temperature,
mask_token_id=diffusion_cfg.mask_token_id,
)
# Log samples
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples."""
if not samples:
return
logger.info("=" * 60)
logger.info("GENERATED SAMPLES")
logger.info("=" * 60)
for i, sample_data in enumerate(samples, 1):
original = sample_data["original"]
masked = sample_data["masked"]
generated = sample_data["generated"]
mask_ratio = sample_data["mask_ratio"]
masked_tokens = sample_data["masked_tokens"]
total_tokens = sample_data["total_tokens"]
logger.info(f"\nSample {i}:")
logger.info(f"\tOriginal ({total_tokens} tokens): {original}")
logger.info(
f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
f"{mask_ratio:.1%}): {masked}"
)
try:
gen_ids = sample_data.get("generated_ids")
orig_ids = sample_data.get("orig_ids")
masked_positions = set(sample_data.get("masked_positions") or [])
if isinstance(gen_ids, list) and isinstance(orig_ids, list):
styles: list[str] = []
for i, tid in enumerate(gen_ids):
if i in masked_positions:
if i < len(orig_ids) and tid == orig_ids[i]:
styles.append("green")
elif i < len(orig_ids):
styles.append("red")
else:
styles.append("normal")
else:
same = i < len(orig_ids) and tid == orig_ids[i]
styles.append("dim" if same else "normal")
spans: list[tuple[str, int, int]] = []
if gen_ids:
cur = styles[0]
start = 0
for i in range(1, len(gen_ids)):
s = styles[i]
if s != cur:
spans.append((cur, start, i))
cur, start = s, i
spans.append((cur, start, len(gen_ids)))
parts = []
for style_name, a, b in spans:
chunk_text = self.trainer.processing_class.decode(
gen_ids[a:b], skip_special_tokens=False
)
if style_name == "green":
parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL)
elif style_name == "red":
parts.append(Fore.RED + chunk_text + Style.RESET_ALL)
else:
if style_name == "dim":
parts.append(Style.DIM + chunk_text + Style.RESET_ALL)
else:
parts.append(chunk_text)
logger.info("\tGenerated:\n%s", "".join(parts))
else:
logger.info(f"\tGenerated: {generated}")
except Exception:
logger.info(f"\tGenerated: {generated}")
logger.info("=" * 60)
if self.trainer.cfg.use_wandb:
if wandb.run is not None:
wandb.log(
{
"generated_samples": wandb.Table(
columns=[
"step",
"original",
"masked",
"generated",
"mask_ratio",
"masked_tokens",
"total_tokens",
],
data=[
[
step,
sample["original"],
sample["masked"],
sample["generated"],
f"{sample['mask_ratio']:.1%}",
sample["masked_tokens"],
sample["total_tokens"],
]
for sample in samples
],
)
},
step=step,
)

View File

@@ -0,0 +1,409 @@
"""Sample generation utilities for diffusion training."""
import re
from typing import Any, List, Literal, Optional
import torch
from axolotl.utils.logging import get_logger
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Optional[Any] = None,
num_generation_samples: int = 3,
max_length: int = 100,
num_diffusion_steps: int = 128,
temperature: float = 0.0,
mask_token_id: int = 32000,
mode: Literal["random", "completion"] = "random",
completion_tokens: int = 0,
target_mask_ratio: Optional[float] = None,
) -> List[dict]:
"""
Generate text samples using the diffusion model by randomly masking sequences from
the given dataset and running the reverse diffusion process.
Args:
model: The wrapped or unwrapped model
tokenizer: Tokenizer for encoding/decoding
dataloader: Validation dataloader (for sampling sequences)
num_generation_samples: Number of samples to generate
max_length: Maximum length of sequences to use
num_diffusion_steps: Number of diffusion steps for generation
temperature: Temperature for sampling (0.0 = deterministic)
mask_token_id: Token ID used for masking
mode: "random" masking or "completion" (append masked tokens to the prompt)
completion_tokens: Number of mask tokens to append in completion mode
target_mask_ratio: Optional fixed mask ratio for random mode
Returns:
List of dictionaries with original text, masked text, and generated text
"""
if dataloader is None:
LOG.warning("No validation dataloader provided, cannot generate samples")
return []
unwrapped_model = model.module if hasattr(model, "module") else model
training = unwrapped_model.training
unwrapped_model.eval()
# Resolve device robustly (some modules don't expose `.device`)
device = getattr(unwrapped_model, "device", None)
if device is None:
try:
device = next(unwrapped_model.parameters()).device
except StopIteration:
device = torch.device("cpu")
generations = []
# Sample sequences from validation dataset
sampled_sequences = _sample_sequences_from_dataloader(
dataloader, num_generation_samples, max_length, device
)
LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
# Generate samples using reverse diffusion process
with torch.no_grad():
for sample in sampled_sequences:
if isinstance(sample, dict):
original_sequence = sample.get("input_ids")
labels_seq = sample.get("labels")
attn_seq = sample.get("attention_mask")
else:
original_sequence = sample
labels_seq = None
attn_seq = None
generation_result = generate(
unwrapped_model,
tokenizer,
original_sequence,
num_diffusion_steps,
temperature,
mask_token_id,
mode=mode,
completion_tokens=completion_tokens,
target_mask_ratio=target_mask_ratio,
labels=labels_seq,
attention_mask=attn_seq,
)
generations.append(generation_result)
# Restore prior training state
if training:
unwrapped_model.train()
else:
unwrapped_model.eval()
return generations
def _sample_sequences_from_dataloader(
dataloader: Any, num_samples: int, max_length: int, device: torch.device
) -> List[Any]:
"""Sample sequences from validation dataloader."""
sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = []
sample_count = 0
# Skip a random number of batches (we could be more clever about this)
skip_batches = torch.randint(0, 10, (1,)).item()
batch_count = 0
for batch in dataloader:
# Skip some batches for variety
if batch_count < skip_batches:
batch_count += 1
continue
if sample_count >= num_samples:
break
batch_count += 1
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask")
labels = batch.get("labels")
# Randomly sample from sequences in this batch
batch_indices = torch.randperm(input_ids.size(0)).tolist()
for i in batch_indices:
if sample_count >= num_samples:
break
# Get actual sequence length (non-padded)
if attention_mask is not None:
seq_len = attention_mask[i].sum().item()
else:
seq_len = input_ids.size(1)
if seq_len < 10:
continue
# Determine truncation length
max_total = min(seq_len, max_length)
if labels is not None:
labels_i = labels[i][:seq_len]
answer_mask = labels_i != -100
if not answer_mask.any():
# No answer tokens; skip for SFT masking
continue
first_ans_idx = int(
torch.nonzero(answer_mask, as_tuple=False)[0].item()
)
prompt_len = first_ans_idx
if prompt_len >= max_total:
# Prompt alone reaches cap; cannot include any answer
continue
remaining_answer = int(answer_mask[prompt_len:].sum().item())
allowed_answer = max_total - prompt_len
take_answer = min(remaining_answer, allowed_answer)
if take_answer <= 0:
continue
actual_length = prompt_len + take_answer
else:
actual_length = max_total
# Extract the (possibly truncated) sequence
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
attn_seq = (
attention_mask[i][:actual_length].unsqueeze(0).to(device)
if attention_mask is not None
else None
)
if labels is not None:
labels_seq = labels[i][:actual_length].unsqueeze(0).to(device)
sampled_sequences.append(
{
"input_ids": sequence,
"labels": labels_seq,
"attention_mask": attn_seq,
}
)
else:
if attn_seq is not None:
sampled_sequences.append(
{"input_ids": sequence, "attention_mask": attn_seq}
)
else:
sampled_sequences.append(sequence)
sample_count += 1
return sampled_sequences
def generate(
model: torch.nn.Module,
tokenizer: Any,
original_sequence: torch.Tensor,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
*,
mode: Literal["random", "completion"] = "random",
completion_tokens: int = 0,
target_mask_ratio: Optional[float] = None,
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> dict:
"""Generate a single sample using reverse diffusion."""
# Get original text for comparison
original_text = tokenizer.decode(
original_sequence[0].cpu(), skip_special_tokens=True
)
# Build masked sequence
if (
labels is not None
and labels.numel() > 0
and (labels == -100).any()
and (labels != -100).any()
):
# SFT case: completely mask all answer tokens (labels != -100)
total_tokens = original_sequence.size(1)
masked_indices = (labels != -100).to(dtype=torch.bool)
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
masked_tokens = int(masked_indices.sum().item())
mask_ratio = masked_tokens / max(int(total_tokens), 1)
elif mode == "completion" and completion_tokens > 0:
# Append mask tokens to the right for completion
total_tokens = original_sequence.size(1) + int(completion_tokens)
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, -int(completion_tokens) :] = True
append = torch.full(
(1, int(completion_tokens)), mask_token_id, device=original_sequence.device
)
masked_sequence = torch.cat([original_sequence, append], dim=1)
masked_tokens = int(completion_tokens)
mask_ratio = masked_tokens / total_tokens
else:
# Apply random masking with optional fixed ratio
total_tokens = original_sequence.size(1)
if target_mask_ratio is None:
min_ratio, max_ratio = 0.1, 0.7
target_mask_ratio = (
torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
)
target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio)))
# Create random mask indices
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
masked_indices = torch.zeros(
1, total_tokens, dtype=torch.bool, device=original_sequence.device
)
masked_indices[0, mask_positions] = True
# Create masked sequence
masked_sequence = original_sequence.clone()
masked_sequence[masked_indices] = mask_token_id
# Calculate actual mask ratio
masked_tokens = masked_indices.sum().item()
mask_ratio = masked_tokens / total_tokens
# Get masked text for comparison
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
# Run reverse diffusion process
sequence = masked_sequence.clone()
attention_mask = create_bidirectional_attention_mask(
sequence, attention_mask, sample_packing=attention_mask is not None
)
for step in range(num_diffusion_steps):
sequence = _diffusion_step(
model,
sequence,
step,
num_diffusion_steps,
temperature,
mask_token_id,
attention_mask,
)
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
# Collect diagnostic info
final_ids = sequence[0].detach().cpu().tolist()
orig_ids_for_render = original_sequence[0].detach().cpu().tolist()
if masked_indices is not None:
masked_positions = (
torch.where(masked_indices[0])[0].detach().cpu().tolist()
if masked_indices.ndim == 2
else []
)
else:
masked_positions = []
result = {
"original": original_text,
"masked": masked_text,
"generated": generated_text,
"mask_ratio": mask_ratio,
"masked_tokens": masked_tokens,
"total_tokens": total_tokens,
"generated_ids": final_ids,
"masked_positions": masked_positions,
"orig_ids": orig_ids_for_render,
"formatted": (
f"Original: '{original_text}' → Masked: '{masked_text}' "
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
),
}
return result
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
"""Clean up masked text for display."""
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
# Remove literal special token strings
if hasattr(tokenizer, "special_tokens_map"):
for token_value in tokenizer.special_tokens_map.values():
if token_value and isinstance(token_value, str):
cleaned = cleaned.replace(token_value, "")
# Normalize whitespace but preserve newlines
cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n")
cleaned = re.sub(r"[ \t]+", " ", cleaned)
cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip()
return cleaned
def _diffusion_step(
model: torch.nn.Module,
sequence: torch.Tensor,
step: int,
num_diffusion_steps: int,
temperature: float,
mask_token_id: int,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Perform a single diffusion step with remasking."""
# Only process if there are masked tokens remaining
current_mask = sequence == mask_token_id
if not current_mask.any():
return sequence
# Create or use provided attention mask
if attention_mask is None:
batch_size, seq_len = sequence.shape
attention_mask = torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
)
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
# Only sample at currently masked positions
if current_mask.any():
masked_logits = logits[current_mask]
# Apply temperature scaling
if temperature > 0:
scaled_logits = masked_logits / temperature
else:
scaled_logits = masked_logits
# Suppress mask token in outputs
scaled_logits[:, mask_token_id] = -float("inf")
if temperature > 0:
# Add Gumbel noise for sampling
gumbel_noise = -torch.log(
-torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
)
gumbel_logits = scaled_logits + gumbel_noise
predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
else:
predicted_tokens = torch.argmax(scaled_logits, dim=-1)
# Calculate probabilities for confidence scoring
probs = torch.softmax(scaled_logits, dim=-1)
predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]
# Determine how many tokens to unmask this step
remaining_masked = current_mask.sum().item()
if step == num_diffusion_steps - 1:
num_to_unmask = remaining_masked
else:
unmask_ratio = 1.0 / (num_diffusion_steps - step)
num_to_unmask = max(1, int(remaining_masked * unmask_ratio))
# Select highest confidence predictions to unmask
if num_to_unmask >= remaining_masked:
sequence[current_mask] = predicted_tokens
else:
_, top_indices = predicted_token_probs.topk(num_to_unmask)
mask_positions = torch.where(current_mask)[1]
positions_to_unmask = mask_positions[top_indices]
sequence[0, positions_to_unmask] = predicted_tokens[top_indices]
return sequence

View File

@@ -0,0 +1,41 @@
"""Diffusion LM training plugin for Axolotl."""
from peft import PeftModel
from transformers import PreTrainedModel
from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .trainer import DiffusionTrainer
LOG = get_logger(__name__)
class DiffusionPlugin(BasePlugin):
"""
Plugin for diffusion language model training.
This plugin enables diffusion-based training using the LLaDA approach, which uses
random masking and bidirectional attention to train language models.
"""
def __init__(self):
super().__init__()
self.cfg = None
def get_input_args(self) -> str:
"""Returns the pydantic model for LLaDA plugin arguments."""
return "axolotl.integrations.diffusion.DiffusionArgs"
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Perform actions after model is loaded."""
self.cfg = cfg
def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None:
"""Return custom trainer class for diffusion training."""
return DiffusionTrainer
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
"""Configure trainer after creation."""
trainer.set_config(cfg)

View File

@@ -0,0 +1,301 @@
"""Custom trainer for diffusion LM training."""
from typing import Any, Literal
import torch
import torch.nn.functional as F
from torch import nn
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask
LOG = get_logger(__name__)
class DiffusionTrainer(AxolotlTrainer):
"""Custom trainer for diffusion LM training that overrides loss computation."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cfg = None
self._special_token_ids = None
def set_config(self, config: DictDefault):
"""Set config for diffusion training."""
self.cfg = config
self._cache_special_token_ids()
self._resolve_mask_token_id()
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
LOG.info(f"Diffusion: using mask_token_id={token_id}")
if getattr(config.diffusion, "generate_samples", True):
generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback)
def _resolve_mask_token_id(self) -> None:
"""Ensure mask_token_id is valid for the current tokenizer."""
from .utils import resolve_mask_token_id
tokenizer = getattr(self, "processing_class", None)
if tokenizer is None:
return
mid = resolve_mask_token_id(
tokenizer,
self.cfg,
allow_add=True,
model=getattr(self, "model", None),
)
try:
self.cfg.diffusion.mask_token_id = int(mid)
except Exception:
pass
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor],
return_outputs: bool = False,
num_items_in_batch: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Override compute_loss to use diffusion loss."""
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask")
labels = inputs.get("labels")
if input_ids is None:
raise ValueError("input_ids is required for diffusion training")
loss, outputs = self._compute_diffusion_loss(
model, input_ids, attention_mask, labels
)
if return_outputs:
return loss, outputs
return loss
def _cache_special_token_ids(self):
"""Cache special token IDs to avoid repeated tokenizer access."""
if self.processing_class is None:
self._special_token_ids = set()
return
tokenizer = self.processing_class
special_tokens = set()
if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
special_tokens.add(tokenizer.bos_token_id)
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
special_tokens.add(tokenizer.eos_token_id)
if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
special_tokens.add(tokenizer.pad_token_id)
self._special_token_ids = special_tokens
def _forward_process(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
eps: float = 1e-3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward noising process. A timestep is sampled along the process, and tokens are
masked with probability determined by the configured noise schedule.
Args:
input_ids: Input token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
eps: Small epsilon value for minimum masking probability.
Returns:
noisy_batch: Input with some tokens masked.
masked_indices: Boolean mask indicating which tokens were masked.
p_mask: Masking probabilities for each token [batch_size, seq_len].
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Sample random timesteps for each sample in batch
t = torch.rand(batch_size, device=device)
p_mask = (1 - eps) * t + eps # [batch_size]
p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len]
# Don't mask padding tokens if attention_mask is provided
if attention_mask is not None:
valid_mask = attention_mask.bool()
p_mask = p_mask * valid_mask.float()
# Create mask to exclude special tokens
special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
if self._special_token_ids:
for token_id in self._special_token_ids:
special_token_mask |= input_ids == token_id
# Create random mask based on p_mask
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
masked_indices = masked_indices & ~special_token_mask
if attention_mask is not None:
masked_indices = masked_indices & attention_mask.bool()
# For SFT data, only mask answer tokens
if labels is not None:
answer_mask = labels != -100
masked_indices = masked_indices & answer_mask
# Create masked input
mask_token_id = int(self.cfg.diffusion.mask_token_id)
mask_value = torch.full_like(input_ids, mask_token_id)
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
return noisy_batch, masked_indices, p_mask
def _compute_diffusion_loss(
self,
model: nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | Any]:
"""
Compute diffusion loss.
Args:
model: The model to compute loss for.
input_ids: Ground truth token ids [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Labels for SFT training [batch_size, seq_len].
Returns:
loss: Cross-entropy loss.
metrics: Dictionary of metrics.
"""
# Short-circuit empty sequences
if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0:
zero = torch.tensor(
0.0,
device=(input_ids.device if input_ids is not None else None),
requires_grad=True,
)
return zero, {}
# If an attention_mask is provided and all positions are padding for every
# sample in this batch, skip the step.
if attention_mask is not None:
if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all():
zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
return zero, {}
# Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.cfg.diffusion.eps
)
# Create bidirectional attention mask
bidirectional_mask = create_bidirectional_attention_mask(
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
)
# Forward pass
outputs = model(
input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask,
)
logits = outputs.logits
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
batch_indices, seq_indices = valid_indices
masked_logits = logits[batch_indices, seq_indices]
masked_targets = input_ids[batch_indices, seq_indices]
masked_p_mask = p_mask[batch_indices, seq_indices]
# Compute cross-entropy loss without reduction
token_loss = F.cross_entropy(
masked_logits.float(), masked_targets, reduction="none"
)
if self.cfg.diffusion.importance_weighting:
masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask
else:
weighted_loss = token_loss
if labels is not None:
# For SFT data: normalize by answer token count per sample
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # [batch_size]
# Get batch indices for masked tokens
masked_batch_indices = batch_indices
# Sum losses per sample and divide by answer length
batch_size = input_ids.shape[0]
loss_per_sample = torch.zeros(batch_size, device=input_ids.device)
for i in range(batch_size):
sample_mask = masked_batch_indices == i
if sample_mask.sum() > 0:
sample_loss = weighted_loss[sample_mask].sum()
denom = answer_lengths[i].clamp(min=1.0)
loss_per_sample[i] = sample_loss / denom
loss = loss_per_sample.mean()
else:
# Non-SFT: when importance weighting is enabled, use unbiased estimator
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
# for stable scaling across varying mask ratios.
if self.cfg.diffusion.importance_weighting:
loss = weighted_loss.sum() / (
input_ids.shape[0] * input_ids.shape[1]
)
else:
loss = weighted_loss.mean()
ce_loss = token_loss.mean()
# Compute accuracy on masked tokens
with torch.no_grad():
pred_tokens = masked_logits.argmax(dim=-1)
accuracy = (pred_tokens == masked_targets).float().mean()
else:
loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
accuracy = torch.tensor(0.0, device=input_ids.device)
ce_loss = torch.tensor(0.0, device=input_ids.device)
masked_p_mask = torch.tensor(1.0, device=input_ids.device)
avg_p_mask = (
p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0
)
metrics = {
"loss": loss.item(),
"accuracy": accuracy.item(),
"mask_ratio": masked_indices.float().mean().item(),
"num_masked_tokens": (masked_indices.sum().item(), "sum"),
"avg_p_mask": avg_p_mask,
"ce_loss": ce_loss.item(),
}
# If doing SFT training, log answer-specific metrics
if self.cfg.datasets is not None:
with torch.no_grad():
answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
total_answer_tokens = answer_mask.sum().item() # type: ignore
total_tokens = labels.numel() # type: ignore
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
metrics["avg_answer_length"] = answer_lengths.mean().item()
if self.cfg.diffusion.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
self.store_metrics(metrics, train_eval=train_eval)
return loss, outputs

View File

@@ -0,0 +1,159 @@
"""Shared utilities for diffusion integration."""
from __future__ import annotations
from typing import Any, Optional
import torch
from axolotl.utils.dict import DictDefault
def resolve_mask_token_id(
tokenizer: Any,
cfg: DictDefault,
*,
allow_add: bool,
model: Any | None = None,
default_token: str = "<|diffusion_mask|>",
) -> int:
"""Resolve mask token id. Training may add a new special token; inference won't."""
# Determine vocab size if available
vocab_size = None
if tokenizer is not None:
if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None:
try:
vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type]
except Exception:
vocab_size = None
elif hasattr(tokenizer, "__len__"):
try:
vocab_size = int(len(tokenizer))
except Exception:
vocab_size = None
# Use explicit id from config if provided
diffusion_cfg = getattr(cfg, "diffusion", None)
# Fallback to top-level attr names only if nested missing (shouldn't happen)
cfg_id = (
getattr(diffusion_cfg, "mask_token_id", None)
if diffusion_cfg is not None
else getattr(cfg, "diffusion_mask_token_id", None)
)
if isinstance(cfg_id, int) and cfg_id >= 0:
if vocab_size is None or cfg_id < vocab_size:
return int(cfg_id)
def _existing_special_token_id(token_str: str | None) -> int | None:
"""Attempt to resolve an existing special token string to a real ID."""
if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"):
return None
try:
token_id = tokenizer.convert_tokens_to_ids(token_str)
except Exception:
return None
if not isinstance(token_id, int) or token_id < 0:
return None
# Ensure it's registered as special and not UNK, and within vocab
unk_id = getattr(tokenizer, "unk_token_id", None)
specials = set(getattr(tokenizer, "all_special_tokens", []) or [])
addl = set(getattr(tokenizer, "additional_special_tokens", []) or [])
is_special = token_str in specials or token_str in addl
in_vocab = vocab_size is None or token_id < vocab_size
if (
(unk_id is not None and token_id == unk_id)
or not is_special
or not in_vocab
):
return None
return token_id
# Try mask token string if provided
token_str = (
getattr(diffusion_cfg, "mask_token_str", None)
if diffusion_cfg is not None
else getattr(cfg, "diffusion_mask_token_str", None)
)
for candidate in (token_str, default_token):
token_id = _existing_special_token_id(candidate)
if isinstance(token_id, int):
try:
if diffusion_cfg is None:
cfg.diffusion_mask_token_id = int(token_id) # legacy fallback
else:
diffusion_cfg.mask_token_id = int(token_id)
except Exception:
pass
return int(token_id)
# Optionally add and return a dedicated special token during training
if allow_add and hasattr(tokenizer, "add_special_tokens"):
token_to_add = token_str or default_token
try:
tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]})
# Resize embeddings if possible
if (
model is not None
and hasattr(tokenizer, "__len__")
and hasattr(model, "resize_token_embeddings")
):
try:
model.resize_token_embeddings(len(tokenizer))
except Exception:
pass
new_id = tokenizer.convert_tokens_to_ids(token_to_add)
if isinstance(new_id, int) and new_id >= 0:
try:
if diffusion_cfg is None:
cfg.diffusion_mask_token_id = int(new_id) # legacy fallback
else:
diffusion_cfg.mask_token_id = int(new_id)
except Exception:
pass
return int(new_id)
except Exception:
pass
# Fallback to unk or 0 (do not update cfg)
fallback = getattr(tokenizer, "unk_token_id", 0) or 0
return int(fallback)
def create_bidirectional_attention_mask(
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
sample_packing: bool = False,
) -> torch.Tensor:
"""
Create bidirectional attention mask to override default causal masking.
Handles sample-packed sequences where different samples are identified
by different attention mask values.
Args:
input_ids: Input token ids [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
sample_packing: Whether sample packing is enabled
Returns:
bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len]
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
if attention_mask is None or not sample_packing:
return torch.ones(
batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
)
# Handle sample packing: tokens can only attend within their sample
mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1]
mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len]
# Tokens can attend to each other if they have the same non-zero sample ID
bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)
# Add head dimension: [batch_size, 1, seq_len, seq_len]
return bidirectional_mask.unsqueeze(1)

View File

@@ -0,0 +1,3 @@
from .backends import MOEBackend, get_moe_backend_name
__all__ = ["get_moe_backend_name", "MOEBackend"]

View File

@@ -0,0 +1,47 @@
import warnings
from enum import Enum
class MOEBackend(str, Enum):
AUTO = "auto"
TORCH_GROUPED = "torch_grouped"
NAIVE = "naive"
def _probe_torch_grouped() -> bool:
try:
import torch # noqa: F401
# Prefer a simple version check; exact APIs may vary across 2.8+.
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
return ver >= (2, 8)
except Exception:
return False
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
- explicit preferred argument (e.g., from config)
- auto detection
"""
choice = (preferred or "auto").lower()
try:
selected = MOEBackend(choice)
except ValueError:
warnings.warn(
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
)
selected = MOEBackend.AUTO
if selected == MOEBackend.AUTO:
if _probe_torch_grouped():
return MOEBackend.TORCH_GROUPED
return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to naive",
stacklevel=2,
)
return MOEBackend.NAIVE
return selected

View File

@@ -0,0 +1,371 @@
"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
_LOGGER = logging.getLogger("axolotl.moe.grouped")
def available() -> bool:
try:
major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2])
if (major, minor) < (2, 8):
return False
if not torch.cuda.is_available():
return False
sm, _ = torch.cuda.get_device_capability()
if sm < 9:
return False
return hasattr(torch.ops, "_grouped_mm")
except Exception:
return False
def _iter_expert_impls(
experts_module, visited: Optional[set[int]] = None
) -> List[torch.nn.Module]:
if visited is None:
visited = set()
module_id = id(experts_module)
if module_id in visited:
return []
visited.add(module_id)
impls: List[torch.nn.Module] = []
for exp in experts_module:
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
impls.append(candidate)
continue
nested = getattr(candidate, "experts", None)
if nested is not None:
impls.extend(_iter_expert_impls(nested, visited))
continue
raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module"
)
return impls
@dataclass
class _GroupedWeightStorage:
pattern: str
gate: torch.Tensor
up: torch.Tensor
down: torch.Tensor
fused_gate_up: torch.Tensor
dtype: torch.dtype
device: torch.device
def _allocate_fused_gate_up(
num_experts: int,
gate_shape: torch.Size,
up_shape: torch.Size,
*,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if gate_shape[1] != up_shape[1]:
raise RuntimeError(
"torch_grouped: gate and up projections must share the hidden dimension"
)
fused = torch.empty(
(num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
device=device,
dtype=dtype,
)
gate_view = fused[:, : gate_shape[0]]
up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
return fused, gate_view, up_view
def _ensure_grouped_weights(
experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
) -> _GroupedWeightStorage:
storage: Optional[_GroupedWeightStorage] = getattr(
experts_module, "_ax_grouped_storage", None
)
def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
experts_module._ax_grouped_storage = new_storage
return new_storage
# Identify expert parameter layout
if (
hasattr(sample_mod, "w1")
and hasattr(sample_mod, "w3")
and hasattr(sample_mod, "w2")
):
pattern = "swi_glu"
num_experts = len(expert_impls)
w1_shape = sample_mod.w1.weight.shape
w3_shape = sample_mod.w3.weight.shape
w2_shape = sample_mod.w2.weight.shape
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.w1.weight.dtype
and storage.device == sample_mod.w1.weight.device
and storage.gate.shape[1:] == w1_shape
):
return storage
fused, gate, up = _allocate_fused_gate_up(
num_experts,
w1_shape,
w3_shape,
device=sample_mod.w1.weight.device,
dtype=sample_mod.w1.weight.dtype,
)
down = torch.empty(
(num_experts, *w2_shape),
device=sample_mod.w2.weight.device,
dtype=sample_mod.w2.weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.w1.weight.detach())
up[idx].copy_(mod.w3.weight.detach())
down[idx].copy_(mod.w2.weight.detach())
mod.w1.weight.detach_()
mod.w1.weight.set_(gate[idx])
mod.w3.weight.detach_()
mod.w3.weight.set_(up[idx])
mod.w2.weight.detach_()
mod.w2.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
pattern = "fused_gate_up"
gate_weight = sample_mod.gate_up_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == gate_weight.dtype
and storage.device == gate_weight.device
and storage.gate.shape[1:]
== (gate_weight.shape[0] // 2, gate_weight.shape[1])
):
return storage
num_experts = len(expert_impls)
gate_full = torch.empty(
(num_experts, *gate_weight.shape),
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.gate_up_proj.weight.detach_()
mod.gate_up_proj.weight.set_(gate_full[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
inter = gate_weight.shape[0] // 2
gate = gate_full[:, :inter]
up = gate_full[:, inter:]
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=gate_full,
dtype=gate.dtype,
device=gate.device,
)
)
if (
hasattr(sample_mod, "up_proj")
and hasattr(sample_mod, "gate_proj")
and hasattr(sample_mod, "down_proj")
):
pattern = "dual_proj"
up_weight = sample_mod.up_proj.weight
gate_weight = sample_mod.gate_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.up_proj.weight.dtype
and storage.device == sample_mod.up_proj.weight.device
and storage.gate.shape[1:] == gate_weight.shape
):
return storage
num_experts = len(expert_impls)
fused, gate, up = _allocate_fused_gate_up(
num_experts,
gate_weight.shape,
up_weight.shape,
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.gate_proj.weight.detach())
up[idx].copy_(mod.up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.up_proj.weight.detach_()
mod.up_proj.weight.set_(up[idx])
mod.gate_proj.weight.detach_()
mod.gate_proj.weight.set_(gate[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
raise RuntimeError(
"torch_grouped: unsupported expert module layout for grouped weights"
)
def moe_ffn_forward_grouped(
hidden_states: torch.Tensor,
gate_linear: torch.nn.Linear,
experts_module,
top_k: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if not available():
return None, None
bsz, seqlen, hdim = hidden_states.shape
tokens = bsz * seqlen
device = hidden_states.device
routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive",
expert_dtype,
)
return None, None
parent_block = None
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
if parent_ref is not None:
try:
parent_block = parent_ref()
except TypeError:
parent_block = None
expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container)
sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
w_gate = storage.gate
w_up = storage.up
w2 = storage.down
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
shared_out_flat: Optional[torch.Tensor] = None
shared_owner = parent_block if parent_block is not None else experts_module
if hasattr(shared_owner, "shared_expert"):
shared_expert = shared_owner.shared_expert
shared_out_flat = shared_expert(x_flat)
shared_out_flat = shared_out_flat.to(expert_dtype)
shared_gate = getattr(shared_owner, "shared_expert_gate", None)
if shared_gate is not None:
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
gate_vals = torch.sigmoid(gate_input)
shared_out_flat.mul_(gate_vals.to(expert_dtype))
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
flat_idx = topk_idx.view(-1)
num_experts = len(expert_impls)
if flat_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
sorted_experts, perm = torch.sort(flat_idx)
assignments = torch.bincount(sorted_experts, minlength=num_experts)
if assignments.sum() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index)
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_in.contiguous()
w_gate_t = w_gate_t.contiguous()
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
torch.ops.aten.silu_(gate_out)
w_up_t = w_up_t.contiguous()
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out.mul_(up_out)
gate_out = gate_out.contiguous()
w2_t = w2_t.contiguous()
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
down_out.mul_(weights)
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, down_out)
output = combined.view(bsz, seqlen, hdim)
if shared_out_flat is not None:
output = output + shared_out_flat.view(bsz, seqlen, hdim)
return output, router_logits
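As a worked sketch of the routing-to-offsets step in `moe_ffn_forward_grouped` above, assume 4 tokens, 2 experts, and top_k=1; `torch._grouped_mm` and its `offs` argument are the same ones called in this file.
# Sketch: sorted expert assignments become cumulative int32 offsets for _grouped_mm.
import torch

top_k = 1
topk_idx = torch.tensor([[1], [0], [1], [1]])         # chosen expert per token
flat_idx = topk_idx.view(-1)                          # [1, 0, 1, 1]
sorted_experts, perm = torch.sort(flat_idx)           # [0, 1, 1, 1]; perm groups rows by expert
counts = torch.bincount(sorted_experts, minlength=2)  # [1, 3] rows per expert
offsets = torch.cumsum(counts.to(torch.int32), dim=0).to(torch.int32)  # [1, 4]
token_indices = torch.div(perm, top_k, rounding_mode="floor")  # original token ids, expert-sorted
# Rows gathered with token_indices are contiguous per expert: the first offsets[0]
# rows belong to expert 0, the next offsets[1] - offsets[0] rows to expert 1, and
# that tensor plus `offsets` is what each torch._grouped_mm call above consumes.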

View File

@@ -14,6 +14,7 @@ from peft import (
PeftConfig,
PeftMixedModel,
PeftModel,
TaskType,
get_peft_model,
)
from transformers import PreTrainedModel
@@ -98,6 +99,17 @@ def load_lora(
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
# Determine the correct PEFT task type
model_cls = type(model).__name__
if "SequenceClassification" in model_cls:
task_type = TaskType.SEQ_CLS
elif "TokenClassification" in model_cls:
task_type = TaskType.TOKEN_CLS
else:
task_type = TaskType.CAUSAL_LM
lora_config = LoraConfig(
r=cfg.lora_r,
@@ -110,7 +122,7 @@ def load_lora(
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
task_type="CAUSAL_LM",
task_type=task_type,
**lora_config_kwargs,
)

View File

@@ -673,6 +673,33 @@ class ModelLoader:
return hf_ds_cfg
def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel:
"""
Load model with random initialization using from_config.
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
)
else:
model = loader(config=self.model_config)
return model
def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
"""Load model from pretrained weights."""
loader = model_loader_class or self.auto_model_loader
kwargs = {
"config": self.model_config,
"trust_remote_code": self.cfg.trust_remote_code or False,
**self.model_kwargs,
}
return loader.from_pretrained(self.base_model, **kwargs)
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
@@ -687,7 +714,8 @@ class ModelLoader:
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
# Don't delete device_map for QLoRA + FSDP - it was set correctly in
# _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
@@ -716,6 +744,11 @@ class ModelLoader:
or self.cfg.qlora_sharded_model_loading
)
):
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with sharded quantized loading. "
"Loading from pretrained weights instead."
)
quant_storage = self.cfg.torch_dtype
quantization_config = getattr(
self.model_config, "quantization_config", None
@@ -731,33 +764,12 @@ class ModelLoader:
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(config=self.model_config)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
elif self.model_type == "MambaLMHeadModel":
if self.cfg.reinit_weights:
LOG.warning(
"reinit_weights is not supported with MambaLMHeadModel. "
"Loading from pretrained weights instead."
)
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss()
@@ -770,41 +782,27 @@ class ModelLoader:
self.base_model,
**self.model_kwargs,
)
elif (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
# Please don't remove underscore binding without reading the fn docstring.
# Please don't remove underscore binding without reading the fn docstring
_ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
if (
self.model_type
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# Use model type from transformers
model_loader_class = getattr(transformers, self.model_type)
else:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:
self.model = self._load_model_from_pretrained(model_loader_class)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True

View File

@@ -4,6 +4,7 @@ Applies pre- and post-model load patches for various fixes and optimizations.
"""
import importlib.util
import os
from functools import cached_property
import addict
@@ -11,6 +12,7 @@ import transformers
from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
@@ -56,6 +58,8 @@ class PatchManager:
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
# Apply MoE grouped GEMM patches (cfg.moe_backend)
apply_grouped_to_moe_blocks(self.cfg)
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches()
@@ -66,6 +70,7 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
self._apply_patch_deepspeed_zero3()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -78,13 +83,7 @@ class PatchManager:
patch_maybe_log_save_evaluate,
)
patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)
patch_evaluation_loop(patch_fsdp2)
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -147,14 +146,12 @@ class PatchManager:
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.flex_attention:
# from axolotl.monkeypatch.attention.flex_attn import (
# patch_flex_make_mask,
# patch_flex_wrapper,
# )
#
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
# patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_make_mask()
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,
@@ -275,6 +272,7 @@ class PatchManager:
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
cfg=self.cfg,
)
if self.cfg.sample_packing:
@@ -471,3 +469,17 @@ class PatchManager:
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(model=model, cfg=self.cfg)
def _apply_patch_deepspeed_zero3(self):
try:
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
if self.cfg.activation_offloading is True and (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
):
apply_deepspeed_patches()
except ImportError as e:
LOG.warning(f"DeepSpeed patches not applied: {e}")

View File

@@ -296,7 +296,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
)
tokenizer.chat_template = chat_template_string
else:
elif getattr(tokenizer, "chat_template", None) is None:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)

View File

@@ -1,10 +1,7 @@
"""
Common logging module for axolotl
"""
"""Common logging module for axolotl."""
import logging
import os
import sys
from logging import Formatter, Logger, LogRecord
from logging.config import dictConfig
from typing import Any, Dict
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
class AxolotlOrWarnErrorFilter(logging.Filter):
"""
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
(i.e. non-axolotl.INFO, DEBUG, etc. by default).
"""
def __init__(self, **kwargs):
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
class AxolotlLogger(Logger):
"""A Logger that automatically rejects non-axolotl INFOs."""
"""Logger that applies filtering to non-axolotl loggers."""
def __init__(self, name: str, level: int = logging.NOTSET):
super().__init__(name, level)
# set global filter on the logger itself
self.addFilter(AxolotlOrWarnErrorFilter())
if not name.startswith("axolotl"):
self.addFilter(AxolotlOrWarnErrorFilter())
class ColorfulFormatter(Formatter):
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
},
"concise": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
},
"concise_color": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
},
},
"filters": {
"ax_or_warn": {
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
"formatter": "concise",
"filters": ["ax_or_warn"],
"stream": "ext://sys.stdout",
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
"formatter": "concise_color",
"filters": ["ax_or_warn"],
"stream": "ext://sys.stdout",
},
"ax_file_only": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://axolotl.utils.tee.file_only_stream",
},
"root_file_only": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://axolotl.utils.tee.file_only_stream",
},
},
# log level will be superseded by the AxolotlLogger
"root": {
"handlers": ["console"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
"handlers": ["console", "root_file_only"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"handlers": ["color_console", "ax_file_only"],
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
"propagate": False,
},
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)
logging.setLoggerClass(AxolotlLogger)
# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
# Route Python warnings through logging so they reach file handlers
logging.captureWarnings(True)
# Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
if "ACCELERATE_LOG_LEVEL" not in os.environ:
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
"LOG_LEVEL", DEFAULT_LOG_LEVEL
).upper()

View File

@@ -160,9 +160,11 @@ def get_state_dict(self, model, unwrap=True):
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp import (
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True

View File

@@ -1,10 +1,11 @@
"""Flex attention monkey patch"""
import sys
from typing import Optional, Tuple, Union
import torch
import transformers
from packaging import version
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
from axolotl.utils.logging import get_logger
@@ -46,19 +47,33 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
"""
self.training = None
if not self._is_flex_compiled or training != self.training:
self.training = training
if is_torch_less_or_equal("2.5.1"):
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False
)
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info("Flex attention compiled successfully.")
elif version.parse(_torch_version).base_version == "2.6.0" and training:
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
# Fallback, usually the most recent torch 2.7.x+ versions
else:
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
main_process_only=True,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info(
"Flex attention compiled successfully.", main_process_only=True
)
self._is_flex_compiled = True
def __call__(self):
@@ -68,139 +83,3 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
sys.modules[
"transformers.integrations.flex_attention"
].WrappedFlexAttention = WrappedFlexAttention
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
if not is_torch_2_6:
return
from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
)
from torch.nn.attention.flex_attention import (
BlockMask,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
Offset = Union[torch.Tensor, int]
def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d,
value=0,
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Combines the chunk mask with the causal mask for chunked attention.
"""
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
mask_mod_maybe_combined = (
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
)
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = mask_mod_maybe_combined
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
for n in tuple(sys.modules):
if ".modeling_" in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[
n
].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
sys.modules[
n
].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)

View File

@@ -0,0 +1,67 @@
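"""DeepSpeed-related monkeypatches (CheckpointWrapper attribute forwarding)."""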
import importlib
import importlib.util
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_checkpoint_wrapper_setattr():
"""
Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.
This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes
(like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed
ZeRO-3 to fail when gradient checkpointing is enabled.
This issue occurs specifically with:
- QLoRA + DeepSpeed ZeRO-3
- gradient_checkpointing: true
- activation_offloading: true
References:
- https://github.com/deepspeedai/DeepSpeed/issues/7203
- https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458
- https://github.com/axolotl-ai-cloud/axolotl/pull/3102
"""
try:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
# Check if already patched
if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"):
LOG.debug("CheckpointWrapper already patched")
return
original_setattr = CheckpointWrapper.__setattr__
def new_setattr(self, name: str, value) -> None:
if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"):
setattr(self._checkpoint_wrapped_module, name, value)
LOG.debug(
f"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}"
)
else:
original_setattr(self, name, value)
CheckpointWrapper.__setattr__ = new_setattr
CheckpointWrapper._axolotl_setattr_patched = True
LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes")
except ImportError as e:
LOG.debug(f"CheckpointWrapper not available: {e}")
except Exception as e:
LOG.warning(f"Failed to patch CheckpointWrapper: {e}")
def apply_deepspeed_patches():
"""
Apply DeepSpeed-related patches
"""
if importlib.util.find_spec("deepspeed") is not None:
patch_checkpoint_wrapper_setattr()
else:
LOG.debug("DeepSpeed not available, skipping patches")

View File

@@ -149,6 +149,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return MistralAttention
if model_type == "gemma3_text":
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
return Gemma3Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"

View File

@@ -5,9 +5,14 @@ Patches to support multipack for mixtral
import torch
def patch_mixtral_moe_forward_zero3() -> None:
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
import warnings
import torch.nn.functional as F
from axolotl.kernels.moe import backends as _moe_backends
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
hidden_states
@@ -21,21 +26,32 @@ def patch_mixtral_moe_forward_zero3() -> None:
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if (
backend == MOEBackend.TORCH_GROUPED
and not _moe_backends._probe_torch_grouped()
):
warnings.warn(
"torch_grouped selected but not available; falling back to naive",
stacklevel=2,
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(
routing_weights, self.top_k, dim=-1, sorted=False
)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states_rep)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
sel = flat_topk_idx == i
if sel.any():
y[sel] = expert(hidden_states_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
@@ -46,4 +62,23 @@ def patch_mixtral_moe_forward_zero3() -> None:
)
MixtralBlockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
# Wrap forward to support optional torch_grouped backend via config
from axolotl.kernels.moe import torch_grouped as _tg
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
def moe_forward_grouped(self, hidden_states: torch.Tensor):
bsz, seqlen, hdim = hidden_states.shape
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
if y is None:
return moe_forward(self, hidden_states)
return y, router_logits
MixtralSparseMoeBlock.forward = moe_forward_grouped
else:
MixtralSparseMoeBlock.forward = moe_forward

View File

@@ -0,0 +1,133 @@
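"""Best-effort monkeypatches that route known MoE block classes through the grouped GEMM backend."""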
import logging
import weakref
from functools import wraps
import torch
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
_LOG = logging.getLogger("axolotl.moe.patch")
def _patch_block_forward(block_cls, grouped_fn):
"""Replace block_cls.forward with grouped_fn preserving signature."""
block_cls.forward = grouped_fn
def apply_grouped_to_moe_blocks(cfg=None) -> None:
"""
Attempt to patch all known MoE block classes to use the torch_grouped backend
when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
Falls back to original forwards otherwise.
"""
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend != MOEBackend.TORCH_GROUPED:
_LOG.info(
f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
)
return
try:
from axolotl.kernels.moe import torch_grouped as _tg
except Exception:
_LOG.warning("torch_grouped backend import failed; skipping grouped patches")
return
if not _tg.available():
_LOG.warning(
"torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
)
return
# Map of architecture key to (modeling module path, class name or list of class names)
model_mods = {
"mixtral": (
"transformers.models.mixtral.modeling_mixtral",
MOE_ARCH_BLOCK.get("mixtral"),
),
"qwen2_moe": (
"transformers.models.qwen2_moe.modeling_qwen2_moe",
MOE_ARCH_BLOCK.get("qwen2_moe"),
),
"qwen3_moe": (
"transformers.models.qwen3_moe.modeling_qwen3_moe",
MOE_ARCH_BLOCK.get("qwen3_moe"),
),
"jamba": (
"transformers.models.jamba.modeling_jamba",
MOE_ARCH_BLOCK.get("jamba"),
),
"deepseek_v2": (
"transformers.models.deepseek_v2.modeling_deepseek_v2",
MOE_ARCH_BLOCK.get("deepseek_v2"),
),
# Others may not follow standard paths; best-effort import
"dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
"jetmoe": (
"transformers.models.jetmoe.modeling_jetmoe",
MOE_ARCH_BLOCK.get("jetmoe"),
),
"gpt_oss": (
"transformers.models.gpt_oss.modeling_gpt_oss",
MOE_ARCH_BLOCK.get("gpt_oss"),
),
}
def make_grouped_forward(orig_forward):
@wraps(orig_forward)
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, hdim = hidden_states.shape
# expose parent block so grouped backend can access shared expert context
try:
self.experts._ax_parent_block_ref = weakref.ref(self)
except Exception:
pass
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
# One-time log per block instance indicating whether grouped engaged or fallback occurred
if not getattr(self, "_ax_grouped_wrapper_logged", False):
if y is None:
_LOG.warning(
"Grouped wrapper active but fell back to naive for %s",
self.__class__.__name__,
)
else:
_LOG.info(
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
)
self._ax_grouped_wrapper_logged = True
if y is None:
return orig_forward(self, hidden_states, *args, **kwargs)
return y, router_logits
return _grouped_forward
patched = 0
for key, (mod_path, cls_names) in model_mods.items():
if not cls_names:
continue
try:
import importlib
modeling = importlib.import_module(mod_path)
names = cls_names if isinstance(cls_names, list) else [cls_names]
for name in names:
if not hasattr(modeling, name):
continue
block_cls = getattr(modeling, name)
orig_forward = getattr(block_cls, "forward", None)
if orig_forward is None:
continue
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
patched += 1
_LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
except Exception as e:
# Best effort; log and skip this entry
_LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
if patched == 0:
_LOG.warning(
"No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
)
else:
_LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")

View File

@@ -36,12 +36,17 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"glm",
"glm4",
"smollm3",
"granite",
"granitemoe",
"hunyuan_v1_dense",
"hunyuan_v1_moe",
"gpt_oss",
"arcee",
"seed_oss",
]
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -52,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
patch_mixtral_moe_forward_zero3(cfg)
def patch_remote(model_name):

View File

@@ -8,6 +8,94 @@ from typing import List
import torch
class DeepSpeedTiledMLPMoE(torch.autograd.Function):
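"""
Tiled MLP autograd function for MoE blocks whose forward returns a tuple (hidden
states plus router logits): hidden-state shards are concatenated along dim 1, the
second tuple element along dim 0, and backward replays each input shard so gradients
accumulate without materializing the full activation.
"""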
@staticmethod
def forward(
ctx,
fn,
self,
x,
shards,
compute_params,
) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad]
ctx.save_for_backward(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad
x = x.detach()
# detach() unsets `x.requires_grad`, so restore it
x.requires_grad_(x_requires_grad)
incoming_grad = grads[0]
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
if compute_params is not None:
if i + 1 < shards:
for param in compute_params:
param.ds_grad_is_ready = False
else:
# last shard, can add the grad
for param in compute_params:
param.ds_grad_is_ready = True
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
with torch.enable_grad():
output = fn(self, x_shard)
if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
return (None, None, x_grad, None, None)
class TiledMLP(torch.autograd.Function):
"""
TiledMLP implementation using gradient hooks
@@ -31,7 +119,18 @@ class TiledMLP(torch.autograd.Function):
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1)
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@@ -42,6 +141,7 @@ class TiledMLP(torch.autograd.Function):
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad
x = x.detach()
@@ -76,7 +176,10 @@ class TiledMLP(torch.autograd.Function):
with torch.enable_grad():
output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks
grad_accumulator.cleanup()
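A tiny worked example of the tuple re-assembly added above; `fake_moe_forward` and the shapes are illustrative assumptions standing in for an MoE block that returns (hidden_states, router_logits).
# Sketch: shard along the sequence dim, then concatenate hidden states on dim 1
# and router logits on dim 0, mirroring tuple_dim_idx = [1, 0] above.
import torch

shards = 2
x = torch.randn(2, 8, 4)  # (batch, seq, hidden)

def fake_moe_forward(x_shard):
    hidden = x_shard * 2
    logits = torch.randn(x_shard.shape[0] * x_shard.shape[1], 8)  # (tokens, n_experts)
    return hidden, logits

output_shards = [fake_moe_forward(s) for s in torch.chunk(x, chunks=shards, dim=1)]
tuple_dim_idx = [1, 0]
output = tuple(
    torch.cat([shard[i] for shard in output_shards], dim=tuple_dim_idx[i])
    for i in range(len(output_shards[0]))
)
assert output[0].shape == (2, 8, 4) and output[1].shape == (16, 8)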

View File

@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
TiledMLP as DeepSpeedTiledMLP,
)
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
try:
# Dynamically import the module and MLP class
@@ -64,7 +64,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
for p in self._compute_params
)
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
if model_type == "gpt_oss":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
else:
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
else:
self._tiled_mlp_dist_impl = TiledMLP

View File

@@ -28,15 +28,6 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_FSDP2_CODE = """
model.eval()
"""
PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
@@ -46,17 +37,11 @@ def check_evaluation_loop_is_patchable() -> bool:
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
def patch_evaluation_loop(patch_fsdp2: bool):
def patch_evaluation_loop():
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
LOG.info("Trainer.evaluation_loop already patched")
LOG.debug("Trainer.evaluation_loop already patched")
return
# Check if the patterns exist
@@ -75,13 +60,6 @@ def patch_evaluation_loop(patch_fsdp2: bool):
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
)
# Apply FSDP2 eval guard patch if needed
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
)
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
# Rename the function to avoid conflicts
evaluation_loop_source = evaluation_loop_source.replace(
"def evaluation_loop(",
@@ -106,7 +84,7 @@ def patch_evaluation_loop(patch_fsdp2: bool):
)
exec(evaluation_loop_source, globals())
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation")
Trainer.evaluation_loop = axolotl_evaluation_loop
@@ -157,5 +135,5 @@ def patch_maybe_log_save_evaluate():
)
exec(maybe_log_source, globals())
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate

View File

@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning("Empty text requested for tokenization.")
LOG.warning_once("Empty text requested for tokenization.")
return empty
result = self.tokenizer(

View File

@@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -200,10 +196,11 @@ def execute_training(
)
)
LOG.info("Starting trainer...")
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
# if cfg.bf16:
# torch.set_default_dtype(torch.bfloat16)
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
plugin_manager = PluginManager.get_instance()
@@ -234,16 +231,15 @@ def save_trained_model(
# handle QAT
if cfg.qat:
from axolotl.utils.quantization import convert_qat_model_for_ptq
from axolotl.utils.quantization import convert_qat_model
LOG.info("Processing QAT model for saving...")
convert_qat_model_for_ptq(
convert_qat_model(
model,
quantize_embedding=cfg.qat.quantize_embedding,
)
LOG.info(
"QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`."
"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
" with the same config which you used for training."
)
# Handle ReLoRA early return case
if cfg.relora:
@@ -337,9 +333,7 @@ def save_trained_model(
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
from axolotl.integrations.llm_compressor.utils import save_compressed_model
save_compressed_model(
model=model,
@@ -416,7 +410,9 @@ def save_initial_configs(
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
@@ -592,6 +588,9 @@ def train(
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()

Some files were not shown because too many files have changed in this diff.