Compare commits

...

42 Commits

Author SHA1 Message Date
Wing Lian
598c965043 use train_loss for sp test 2026-03-22 12:00:55 -04:00
Wing Lian
a96733930e retry and more info on download failure 2026-03-22 11:09:33 -04:00
Wing Lian
6130e40c37 fix flaky tests; should be using train loss from final step rather than final avg train loss 2026-03-22 10:38:46 -04:00
Wing Lian
5b2e3f00ce fix: handle connection errors when checking user whoami (#3529) 2026-03-22 09:11:17 -04:00
Wing Lian
fc3b3d1d4e synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
2026-03-21 22:47:26 -04:00
Wing Lian
c9df6efdc2 support offloading layers to CPU (#3512) [skip ci]
* support offloading layers to CPU

* chore: lint

* revert change

* update docs
2026-03-21 22:47:02 -04:00
Wing Lian
0ee98a0309 fix token state json and mistral tokenizer issue (#3522) [skip ci]
* fix token state json and mistral tokenizer issue

* centralize constants

* forgot to commit constants file

* Fix weakref in pickling relora state dict

* make curl a bit quieter so it doesn't log 2K lines

* fix path traversal for olmoe test

* more test fixes that weren't flagged previously

* chore: lint

* skip tests that fail b/c of OutOfResources

* scattermoe as slow tests

* update fbgemm-genai for torch 2.10
2026-03-21 22:46:10 -04:00
Wing Lian
2c05847a5f reduce autotune search space (#3525) [skip ci]
* reduce autotune search space

* consistent docstrings
2026-03-21 18:30:15 -04:00
Wing Lian
b0294b3427 handle qwen3.5 moe loading (#3523) [skip ci] 2026-03-20 09:25:16 -04:00
Avaya Aggarwal
1bcfc08c90 feat: add support and end-to-end tests for multiple custom optimizers… (#3457) [skip ci]
* feat: add support and end-to-end tests for multiple custom optimizers including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.

* feat: Add standalone flashoptim integration test and E2E tests for various custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.

* feat: introduce Pydantic schema validation for dataset, attention, and training configurations.

* feat: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.

* test: fix assertion in flash optimizers test to compare class names directly

* fix: address PR review - reuse require_torch_2_7_0 decorator, remove fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 08:24:44 -04:00
NanoCode012
5a5cf30b26 fix: add dequant bf16 repo (#3507) [skip ci] 2026-03-20 17:11:46 +07:00
Avaya Aggarwal
7ddfb2d8a0 cleanup: remove dead SDPA patches (#3488) [skip ci]
Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:

- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py

Closes axolotl-ai-cloud/axolotl#3331
2026-03-20 17:10:41 +07:00
Owen Arliawan
c57acef2c7 Qwen3.5-MoE example config with lora_target_modules regex (#3515) [skip ci]
* lora target modules with regex

* updates

* fsdp for non moe

* update wording

* chore: cleanup and lint

* chore: cleanup docs from merge

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:52:46 +07:00
Lorenzo Baraldi
038ffe3f26 fix: solved double sequence partition from SequenceParallelContextManager and Accelerate's native CP (#3498) 2026-03-20 16:27:24 +07:00
VED
c13cb7c853 feat: add nemotron config (#3506)
* nemotron config exp

* Update examples/nemotron/nemotron-mini-4b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-03-20 16:23:42 +07:00
VED
b3823cc6b0 fix: gemma3 configs (#3500) [skip ci]
* gemma fft , text fix

* good lint
2026-03-20 16:14:06 +07:00
VED
113d275bd9 qwen docs + new config (#3499) [skip ci]
* qwen docs + new config

* docss lint

* simplify comments

* read me

* lint comments

* Update docs/multimodal.qmd

* Update docs/multimodal.qmd

* Update examples/qwen3.5/9b-fft-vision.yaml

* chore: fix link and incorrect points

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:13:34 +07:00
VED
7920fe74ec fix num_labels= 1 test fail (#3493) [skip ci]
* trl_num_lables=1

* casual num_lables=1,rwd model

* lint
2026-03-20 16:12:23 +07:00
Wing Lian
1fc86d5295 Scattermoe LoRA optimizations (#3513)
* optimize moe + lora

* more scattermoe optims

* selective dequant

* add correctness unit tests and benchmarks for scattermoe + lora

* handle base+lora split kernel for older moe models

* chore: lint

* fix casting for H200 and B200

* register pressure estimation and pruning for h200/b200

* use soft limit for pruning

* qkv patch for qwen3.5moe

* support text_model for qwen3.5 moe

* nesting of qwen3

* use udpated cce with zero3 support

* Fix decomposed backward for QKV and O projections

eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
2026-03-19 23:07:42 -04:00
Wing Lian
bb483ad4c4 make the CI fail GitHub Actions on test failures (#3517)
* make the CI fail GitHub Actions on test failures

* use model bundle

* install zstd for compressed model artifact
2026-03-19 08:29:24 -04:00
Wing Lian
163bd4dd5a use custom triton kernels for entropy from logits and selective softmax (#3510)
* use custom triton kernels for entropy from logits and selective softmax

* PR comments fixes

* fix out of bounds, include tests, include benchmarks

* chore: lint
2026-03-19 02:02:43 -04:00
Wing Lian
f291ac029c fix for flaky tests in lora ops kernels w autotune (#3511) [skip ci]
* fix for flaky tests in lora ops kernels w autotune

* attempt 2 to fix
2026-03-19 01:18:47 -04:00
Wing Lian
5ef3f28340 Support for Async GRPO (#3486)
* async grpo support

* implement data producer

* use fast async

* handle call to create data producer

* fix liger kernel setup

* fix replay buffer

* chore: lint

* make gpus go brrr

* chore: lint

* inplace div_, unwrap model for logits in bf16

* fuse selective softmax and empty cuda cache on each scoring step

* remove waiting for synch time and fix race

* make fp8 work and allow lora kernels w rl

* grpo with lora vllm sync and fixes for sharded distributed

* update docs

* more patches so it works against trl main

* address PR feedback for corerabbit
2026-03-17 11:42:47 -04:00
Aarush
999b3fec2e fix: replace shell=True subprocess with argument list in modal CLI (#3487)
* fix: replace shell=True subprocess with argument list in modal CLI

Using shell=True with a formatted string containing docker_image
(a user-controlled value) is a command injection risk (Bandit B602).
Replace with an argument list, which passes args directly to the
process without shell interpretation, removing the nosec annotation.

* fix: add nosec annotation to suppress bandit B603/B607 warnings

Removing shell=True (B602) surfaces B603 (subprocess without shell)
and B607 (partial executable path for 'docker'). Use bare # nosec
to suppress both, consistent with other nosec usages in the codebase.
2026-03-17 08:53:13 -04:00
Wing Lian
8f3fb517b3 consolidate behavioud of routing in scattermoe kernels (#3475)
* consolidate behavioud of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
2026-03-16 23:47:40 -04:00
Wing Lian
830e9f7eaf automatically enable tf32 if supported (#3473) [skip ci]
* automatically enable tf32 if supported

* update fixtures

* handle only when True

* Address CR comments

* address readability from pr comment

* simplify
2026-03-16 23:47:00 -04:00
NanoCode012
d230cbbde3 chore(doc): update readme (#3503) [skip ci] 2026-03-17 09:43:24 +07:00
NanoCode012
a098df527b feat: add Mistral Small 4 (#3502)
* feat: add mistral small 4

* fix: update mistral common

* fix: deepcopy when passing in tokenizer

* feat: add doc on reasoning and thinking section

* fix: don't use custom tokenizer and quantize experts

* chore: update docs and configs

* chore: update doc to follow official name

* feat: update cce to include mistral4

* chore: move

* fix: naming

* fix: test mock breaking get_text_config check

* fix: enable CCE and add expert block targetting to configs

* chore: docs

* fix: use act checkpointing

* chore: doc

* chore: docs

* chore: docs
2026-03-17 09:39:05 +07:00
NanoCode012
7da5f94379 feat: add FA4 (#3481)
* feat: add FA4

* chore: update docs

* fix: recommend FA4 for those with compatible devices

* fix: adjust import check and add head_dim check

* chore: add limitation to doc

* fix: log warning and quit if cannot import validator

* chore: simplify

* fix: add caveat with FA2 shadow dir
2026-03-16 00:13:18 -04:00
NanoCode012
4a5876df7a fix: explicit set workflow permission and move secrets to necessary (#3484) [skip ci]
* fix: explicit set workflow permission and move secrets to necessary
steps only

* fix: comment

* fix: more permission restrict

* chore: add read for pypi
2026-03-16 00:13:05 -04:00
Aarush
defee62d99 fix: fix CONTRIBUTING.md placeholders, bare except clauses, and add convert.py tests (#3485) [skip ci]
* docs: fix codestyle placeholders in CONTRIBUTING.md

Replace unresolved {codestyle} and {URLofCodestyle} template
variables with Ruff, the project's actual linter/formatter
as configured in .pre-commit-config.yaml.

* fix: replace bare except clauses with specific exception types

- quantization.py: use except ImportError for optional torchao imports
  (consistent with line 48 which already uses ImportError correctly)
- cli/config.py: use except (RuntimeError, AssertionError) for CUDA
  device property query

Prevents masking unrelated errors like KeyboardInterrupt or SystemExit.

* test: add unit tests for convert.py JSON/JSONL utilities

Cover FileReader, FileWriter, StdoutWriter, JsonParser,
JsonlSerializer, and JsonToJsonlConverter with 8 test cases
including roundtrip and edge case (empty list) scenarios.

Previously this module had zero test coverage.

* fix: address CodeRabbit review feedback

- quantization.py: catch (ImportError, RuntimeError) for optional
  torchao imports; CUDA wheel/GPU mismatches raise RuntimeError,
  not ImportError
- convert.py: remove unused output_file_path parameter from
  JsonToJsonlConverter.convert() — FileWriter already holds the
  output path from construction
- tests/test_convert.py: update call site to match new signature
2026-03-16 00:12:40 -04:00
VED
f56efdb4ab fix: high eval loss w/ sample packing (#3478) [skip ci]
* check if eval_sp

* radable condition
2026-03-15 22:11:23 -04:00
NanoCode012
d8a646c80d chore: logging cleanup (#3482) [skip ci] 2026-03-15 22:10:57 -04:00
VED
a806704e94 moe quant patch for merge miss match (#3483)
* moe quant patch for merge miss match

* lint

* revert test + fix moe patch

* comment fixxes

* e2e tests

* mismatch fixx tested

* mis match fix wwith vllm compatablity + test

* comment lint

* fix: missing os import, duplicate no op

* chore: simplify comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-15 22:10:30 -04:00
Wing Lian
d8a05744d7 Reverts commits 79908b3c6, 083c5a042, e1ff75624, ff77fa248. (#3496)
The non-root user approach had multiple issues with RunPod
compatibility, sudo PATH handling, and tmux in exec sessions.
Restoring root as the default user for now.
2026-03-13 11:54:09 -04:00
Wing Lian
ff77fa2488 preserve env for root -> ubuntu user (#3495) 2026-03-13 10:19:34 -04:00
Wing Lian
e1ff756245 become the ubuntu user when root logs in (#3494) 2026-03-13 09:06:54 -04:00
Wing Lian
083c5a0421 check ubuntu user and set uv python dir (#3492) 2026-03-12 23:20:54 -04:00
Wing Lian
79908b3c6e use ubuntu user instead of root for uv docker images (#3491) 2026-03-12 20:41:13 -04:00
Wing Lian
819b157c7b swap around what we're building for docker (#3490)
* remove cloud configuration we don't base image for

* but we do want it for uv
2026-03-11 21:45:13 -04:00
Wing Lian
fccc712dae builds for py312-cu128-torch2.9.1 (#3489) 2026-03-11 20:09:03 -04:00
NanoCode012
23ad40bdd5 fix: disable async load when loading quantized bnb 2026-03-11 13:18:27 +07:00
170 changed files with 13444 additions and 689 deletions

View File

@@ -68,7 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
Use the pre-commit linter to ensure that your code is formatted consistently.
```bash
@@ -83,6 +83,6 @@ Write clear and concise commit messages that briefly describe the changes made i
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})
- [Ruff](https://docs.astral.sh/ruff/)
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

View File

@@ -15,6 +15,9 @@ on:
- '.github/workflows/base.yml'
workflow_dispatch:
permissions:
contents: read
jobs:
build-base:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -124,7 +127,7 @@ jobs:
images: |
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -132,7 +135,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
@@ -173,6 +176,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -239,7 +250,7 @@ jobs:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -247,7 +258,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/${{ matrix.dockerfile }}

View File

@@ -13,6 +13,9 @@ on:
- ".pre-commit-config.yaml"
workflow_dispatch:
permissions:
contents: read
jobs:
pre-commit:
name: pre-commit

View File

@@ -8,6 +8,9 @@ on:
- "v*"
workflow_dispatch:
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -110,6 +113,12 @@ jobs:
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
@@ -174,6 +183,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
@@ -259,6 +269,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
@@ -266,6 +277,12 @@ jobs:
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
@@ -326,6 +343,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128

View File

@@ -20,6 +20,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
@@ -78,8 +81,9 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run -m cicd.multigpu

View File

@@ -5,6 +5,9 @@ on:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}

View File

@@ -5,6 +5,8 @@ on:
- cron: '0 0 1 * *' # Run monthly
workflow_dispatch: # Manual kickoff
permissions: {}
jobs:
auto-update:
runs-on: ubuntu-latest

View File

@@ -3,9 +3,11 @@ name: publish pypi
on:
push:
tags:
- 'v*'
- "v*"
workflow_dispatch:
permissions: {}
jobs:
setup_release:
name: Create Release
@@ -28,7 +30,8 @@ jobs:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
contents: read
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v4
@@ -46,7 +49,7 @@ jobs:
- name: Extract tag name
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
- name: Update version in VERSION file
run: |

View File

@@ -8,6 +8,9 @@ on:
paths:
- '.github/workflows/tests-nightly.yml'
permissions:
contents: read
jobs:
pre-commit:
name: pre-commit
@@ -156,8 +159,9 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
@@ -198,7 +202,8 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.multigpu

View File

@@ -28,6 +28,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
TRANSFORMERS_IS_CI: "yes"
@@ -303,9 +306,10 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -371,9 +375,10 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -413,7 +418,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

View File

@@ -30,7 +30,7 @@
## 🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
@@ -75,7 +75,7 @@ Features:
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

View File

@@ -128,11 +128,9 @@ quartodoc:
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils

208
benchmarks/bench_entropy.py Normal file
View File

@@ -0,0 +1,208 @@
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
"""
import gc
import statistics
import torch
import torch.nn.functional as F
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
"""Original chunked implementation (reference)."""
original_shape = logits.shape[:-1]
num_classes = logits.shape[-1]
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, logits, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = fn(logits, chunk_size=128)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
del out
return times
def profile_memory(fn, logits, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(logits, chunk_size=128)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_contiguous():
print("=" * 60)
print(
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(1, 16384),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
t_orig = profile_time(entropy_from_logits_original, logits)
t_triton = profile_time(entropy_from_logits, logits)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(entropy_from_logits_original, logits)
m_triton = profile_memory(entropy_from_logits, logits)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits
_clean_gpu()
def benchmark_noncontiguous():
print("\n" + "=" * 60)
print(
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(4, 2048, "transpose"),
(4, 8192, "transpose"),
(8, 2048, "transpose"),
(4, 4096, "slice_batch"),
]
for B, L, method in configs:
torch.manual_seed(42)
if method == "transpose":
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw.transpose(0, 1)
raw_gb = L * B * V * 2 / 1e9
elif method == "slice_batch":
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw[::2]
raw_gb = B * 2 * L * V * 2 / 1e9
else:
continue
if raw_gb > 28:
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
del raw, logits_nc
torch.cuda.empty_cache()
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
print(f"{'' * 60}")
def original_with_copy(logits, chunk_size=128):
return entropy_from_logits_original(
logits.contiguous(), chunk_size=chunk_size
)
t_orig = profile_time(original_with_copy, logits_nc)
t_triton = profile_time(entropy_from_logits, logits_nc)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" orig+copy: {fmt(t_orig, 'ms')}")
print(f" triton-strided:{fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(original_with_copy, logits_nc)
m_triton = profile_memory(entropy_from_logits, logits_nc)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" orig+copy: {fmt(m_orig, 'MB')}")
print(f" triton-strided:{fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del raw, logits_nc
_clean_gpu()
if __name__ == "__main__":
benchmark_contiguous()
benchmark_noncontiguous()

View File

@@ -0,0 +1,284 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,191 @@
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
"""
import gc
import statistics
import torch
from axolotl.monkeypatch.trainer.utils import (
selective_log_softmax,
selective_log_softmax_original,
)
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, args, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
fn(*args)
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn(*args)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return times
def profile_memory(fn, args, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(*args)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(*args)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_forward():
print("=" * 60)
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(selective_log_softmax_original, (logits, index))
t_triton = profile_time(selective_log_softmax, (logits, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
m_triton = profile_memory(selective_log_softmax, (logits, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits, index
_clean_gpu()
def benchmark_backward():
print("\n" + "=" * 60)
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
def fwd_bwd_original(logits, index):
logits.grad = None
out = selective_log_softmax_original(logits, index)
out.sum().backward()
def fwd_bwd_triton(logits, index):
logits.grad = None
out = selective_log_softmax(logits, index)
out.sum().backward()
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 20:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits_orig = torch.randn(
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
logits_tri = logits_orig.detach().clone().requires_grad_(True)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" FWD+BWD TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" FWD+BWD MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits_orig, logits_tri, index
_clean_gpu()
if __name__ == "__main__":
benchmark_forward()
benchmark_backward()

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace

View File

@@ -3,11 +3,13 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \

View File

@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")

View File

@@ -37,6 +37,7 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -13,9 +13,10 @@ sdp_attention: true
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
## Flash Attention
Uses efficient kernels to compute attention.
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
```yaml
flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
```bash
pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
@@ -50,6 +50,44 @@ cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
```
Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```
::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.
:::
::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.

View File

@@ -1,5 +1,5 @@
---
title: Gradient Checkpointing and Activation Offloading
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,3 +27,33 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
### Enabling Layer Offloading
```yaml
layer_offloading: true
```
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
trainable adapter weights stay on GPU permanently.
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
prefetched asynchronously on a separate CUDA stream for overlap.
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
previous layer is prefetched.
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
that is kept on GPU at any given time.
**Requirements:**
- CUDA GPU (CPU-only training is not supported for this feature)
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
- Best combined with LoRA/QLoRA where most parameters are frozen

View File

@@ -13,12 +13,14 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Mistral-Small-4](#sec-mistral-small-4)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
@@ -108,6 +110,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
### Mistral-Small-4 {#sec-mistral-small-4}
```yaml
base_model: mistralai/Mistral-Small-4-119B-2603
```
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
@@ -184,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.

View File

@@ -54,6 +54,13 @@ These techniques save VRAM by changing how activations are handled.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Layer Offloading
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.

View File

@@ -721,6 +721,213 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
#### Async GRPO
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
```yaml
trl:
use_data_producer: true # Enable data producer protocol
use_vllm: true
async_prefetch: true # Generate rollouts in background thread
prefetch_depth: 1 # Number of rollouts to prefetch
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```
::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::
##### vLLM LoRA Sync
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
vllm_lora_sync: true # Enable native LoRA sync
```
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```
Then start training on a separate GPU:
```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::
##### Streaming Partial Batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
```yaml
trl:
streaming_partial_batch: true
```
##### Importance Sampling Correction
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
```yaml
trl:
vllm_importance_sampling_correction: true # Enable IS correction
importance_sampling_level: token # 'token' or 'sequence'
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
##### Replay Buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
```yaml
trl:
replay_buffer_size: 100 # Max cached groups (0 = disabled)
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```
::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
##### Deferred Re-rolling
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
```yaml
trl:
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
reroll_max_groups: 1 # Max groups to replace per batch
```
##### Zero-Advantage Batch Skipping
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
```yaml
trl:
skip_zero_advantage_batches: true # default
```
##### Parallel Reward Workers
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
```yaml
trl:
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```
##### Full Async GRPO Example
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.35
dtype: auto
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
rl: grpo
trl:
use_data_producer: true
use_vllm: true
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
vllm_lora_sync: true
streaming_partial_batch: true
vllm_importance_sampling_correction: true
off_policy_mask_threshold: 0.5
importance_sampling_level: token
num_generations: 8
max_completion_length: 512
reward_funcs:
- rewards.accuracy_reward
reroll_start_fraction: 0.5
replay_buffer_size: 100
reward_num_workers: 4
skip_zero_advantage_batches: true
datasets:
- path: AI-MO/NuminaMath-TIR
type: rewards.prompt_transform
split: train
gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
##### Multi-GPU Async GRPO
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
**FSDP:**
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
use_reentrant: false
```
**DeepSpeed ZeRO-3:**
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
use_reentrant: true # Required for ZeRO-3
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```
::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
},
{

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,9 +1,5 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_model_dir:

View File

@@ -0,0 +1,82 @@
# Finetune Mistral Small 4 with Axolotl
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
3. Install transformers from main
```bash
pip install git+https://github.com/huggingface/transformers.git
```
4. Run one of the example configs:
```bash
# text-only
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
axolotl train examples/mistral4/fft-text.yml
# text + vision
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
axolotl train examples/mistral4/fft-vision.yml
```
Note: FFT configs provided as reference. Please adjust hyperparameters as needed.
## Reasoning Effort
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps
Pass it via `chat_template_kwargs` under your dataset config:
```yaml
datasets:
- path: your/dataset
type: chat_template
chat_template_kwargs:
reasoning_effort: high
```
## Thinking Support
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
```yaml
datasets:
- path: your/thinking-dataset
type: chat_template
message_property_mappings:
role: role
content: content
thinking: thinking
chat_template_kwargs:
reasoning_effort: high
```
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
## Tips
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Related Resources
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,58 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# only train language model layers, freeze vision tower
unfrozen_parameters:
- model.language_model.*
- lm_head
- embed_tokens
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,57 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# vision requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,58 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,63 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
# vision chat template requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,57 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
special_tokens:

View File

@@ -0,0 +1,84 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,7 +32,11 @@ lora_target_modules:
- v_proj
- o_proj
#lora_target_parameters:
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
@@ -52,7 +56,6 @@ learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

View File

@@ -0,0 +1,59 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,81 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,9 +1,7 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
# the text-only path. For multimodal (image+text) fine-tuning, add image
# columns to your dataset following axolotl's multimodal dataset format.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -0,0 +1,85 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -32,7 +32,11 @@ lora_target_modules:
- v_proj
- o_proj
#lora_target_parameters:
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj

View File

@@ -0,0 +1,49 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,10 +1,6 @@
base_model: Qwen/Qwen3.5-7B
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
@@ -30,8 +26,6 @@ lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
# the same transformer stack, so standard attention targets work for both modalities.
lora_target_modules:
- q_proj
- k_proj

View File

@@ -1,15 +1,6 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
## Getting started
@@ -18,35 +9,69 @@ Available configs:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Pick any config from the table below and run:
```bash
axolotl train examples/qwen3.5/<config>.yaml
```
Available configs:
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
### Gated DeltaNet Linear Attention
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
```yaml
lora_target_modules:
# ... standard projections ...
- linear_attn.in_proj_qkv
- linear_attn.in_proj_z
- linear_attn.out_proj
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Run a finetuning example:
### Routed Experts (MoE)
```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
### Shared Experts (MoE)
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
```yaml
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
```
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- For inference hyp, please see the respective model card details.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
## Optimization Guides

View File

@@ -61,5 +61,11 @@ skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.pytest.ini_options]
addopts = "-m 'not slow'"
markers = [
"slow: marks tests as slow",
]
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.8
mistral-common==1.10.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
)

View File

@@ -81,16 +81,23 @@ def parse_requirements(extras_require_map):
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9):
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.17.1"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:

View File

@@ -3,6 +3,7 @@
import os
from pathlib import Path
import httpcore
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
@@ -47,7 +48,7 @@ def check_user_token() -> bool:
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:
except (HTTPError, httpcore.ConnectError):
LOG.warning(
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
)

View File

@@ -90,9 +90,8 @@ class ModalCloud(Cloud):
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
manifest = subprocess.check_output( # nosec
["docker", "manifest", "inspect", docker_image],
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:

View File

@@ -11,7 +11,7 @@ from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
@@ -300,7 +300,7 @@ def load_cfg(
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except:
except (RuntimeError, AssertionError):
gpu_version = None
prepare_plugins(cfg)
@@ -310,6 +310,7 @@ def load_cfg(
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"tf32": is_torch_tf32_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},

View File

@@ -196,12 +196,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)

View File

@@ -38,7 +38,18 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
if (
serve_module is None
and getattr(cfg, "trl", None)
and getattr(cfg.trl, "vllm_lora_sync", False)
):
serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
serve_module = "trl.scripts.vllm_serve"
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -68,7 +79,7 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
vllm_script_args = AxolotlScriptArguments(
base_kwargs = dict(
model=model,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
@@ -78,7 +89,21 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
# Use LoRAScriptArguments when serving with native LoRA support
if serve_module == "axolotl.scripts.vllm_serve_lora":
from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
lora_kwargs = {}
if hasattr(cfg, "lora_r") and cfg.lora_r:
lora_kwargs["max_lora_rank"] = cfg.lora_r
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
else:
vllm_script_args = AxolotlScriptArguments(
**base_kwargs,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)

View File

@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"mistral4": "Mistral4MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",

View File

@@ -67,7 +67,7 @@ class JsonToJsonlConverter:
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(self, input_file_path, output_file_path):
def convert(self, input_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
training_args_kwargs["tf32"] = True if self.cfg.tf32 is True else False
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adamw":
from flashoptim import FlashAdamW
optimizer_cls = FlashAdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adam":
from flashoptim import FlashAdam
optimizer_cls = FlashAdam
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_sgd":
from flashoptim import FlashSGD
optimizer_cls = FlashSGD
elif self.cfg.optimizer == "flash_sgdw":
from flashoptim import FlashSGDW
optimizer_cls = FlashSGDW
elif self.cfg.optimizer == "flash_lion":
from flashoptim import FlashLion
optimizer_cls = FlashLion
if "betas" in adam_kwargs:
optimizer_kwargs["betas"] = adam_kwargs["betas"]
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
@@ -484,6 +508,8 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.layer_offloading:
training_args_kwargs["layer_offloading"] = True
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False

View File

@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,

View File

@@ -54,8 +54,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
@@ -151,7 +159,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
training_args_cls = GRPOStrategy.get_training_args_class(
async_grpo=async_grpo
)
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
@@ -191,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
):
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
@@ -217,13 +238,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
# transformers' validate_quantization_for_training blocks FP8 because
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
# (base weights stay frozen in FP8).
_orig_validate_quant = None
if (
self.cfg.adapter
and hasattr(self.model, "is_quantized")
and self.model.is_quantized
):
import transformers.trainer as _trainer_module
_orig_validate_quant = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
finally:
if _orig_validate_quant is not None:
import transformers.trainer as _trainer_module
_trainer_module.validate_quantization_for_training = (
_orig_validate_quant
)
if self.cfg.fsdp_config or self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:

View File

@@ -29,10 +29,12 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
LayerOffloadingMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -51,8 +53,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
@@ -67,6 +67,7 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
LayerOffloadingMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,

View File

@@ -0,0 +1 @@
TOKENS_STATE_FILE = "tokens_state.json"

View File

@@ -2,7 +2,8 @@
Axolotl specific DPO args
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from trl import DPOConfig
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
from requests import HTTPError
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOTrainer,
)
@@ -27,14 +28,31 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
cls,
sequence_parallel: bool,
async_grpo: bool = False,
) -> (
type[AxolotlGRPOTrainer]
| type[AxolotlGRPOSequenceParallelTrainer]
| type[AxolotlAsyncGRPOTrainer]
):
if sequence_parallel and async_grpo:
raise ValueError(
"sequence_parallel and async_grpo cannot both be enabled. "
"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
)
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
return AxolotlAsyncGRPOTrainer
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
def get_training_args_class(
cls, async_grpo: bool = False
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
if async_grpo:
return AxolotlAsyncGRPOConfig
return AxolotlGRPOConfig
@classmethod
@@ -124,13 +142,63 @@ class GRPOStrategy:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
# Async GRPO fields
if getattr(trl, "use_data_producer", None) is not None:
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
if getattr(trl, "async_prefetch", None) is not None:
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "prefetch_depth", None) is not None:
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
if getattr(trl, "vllm_sync_interval", None) is not None:
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "streaming_partial_batch", None) is not None:
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
if getattr(trl, "streaming_min_groups", None) is not None:
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_correction"] = (
trl.vllm_importance_sampling_correction
)
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_mode"] = (
trl.vllm_importance_sampling_mode
)
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_cap"] = (
trl.vllm_importance_sampling_cap
)
if getattr(trl, "off_policy_mask_threshold", None) is not None:
grpo_args_kwargs["off_policy_mask_threshold"] = (
trl.off_policy_mask_threshold
)
if getattr(trl, "use_bias_correction_kl", None) is not None:
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
# Fast Async GRPO fields
if getattr(trl, "reward_num_workers", None) is not None:
grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
if getattr(trl, "replay_buffer_size", None) is not None:
grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
if getattr(trl, "replay_recompute_logps", None) is not None:
grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
if getattr(trl, "reroll_start_fraction", None) is not None:
grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
if getattr(trl, "reroll_max_groups", None) is not None:
grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
if getattr(trl, "skip_zero_advantage_batches", None) is not None:
grpo_args_kwargs["skip_zero_advantage_batches"] = (
trl.skip_zero_advantage_batches
)
if getattr(trl, "vllm_lora_sync", None) is not None:
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
return grpo_args_kwargs
@classmethod

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
context_parallel_size: int | None = None
@dataclass
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
context_parallel_size: int | None = None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,768 @@
# Copyright 2020-2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Experimental GRPO extensions: parallel reward workers, replay buffer,
deferred re-roll, and zero-advantage skipping.
These features are built as subclasses of GRPOTrainer and GRPODataProducer,
using the hook system (_compute_rewards_for_batch, _post_advantage_hook,
_pre_produce_hook) defined in the base classes.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from dataclasses import dataclass, field
import torch
from torch import nn
from trl import GRPOTrainer
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOConfig,
AsyncGRPOTrainer,
GRPODataProducer,
)
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Extended config
# ---------------------------------------------------------------------------
@dataclass
class FastAsyncGRPOConfig(AsyncGRPOConfig):
"""GRPOConfig with additional experimental parameters."""
reward_num_workers: int = field(
default=1,
metadata={
"help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
},
)
replay_buffer_size: int = field(
default=0,
metadata={
"help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
},
)
replay_recompute_logps: bool = field(
default=True,
metadata={
"help": "When True (default), recompute old_per_token_logps for replayed groups using the current "
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
"Only relevant when replay_buffer_size > 0."
},
)
reroll_start_fraction: float = field(
default=0.5,
metadata={
"help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
},
)
reroll_max_groups: int = field(
default=1,
metadata={
"help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
},
)
skip_zero_advantage_batches: bool = field(
default=True,
metadata={
"help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
"logged with skipped_zero_adv_batches=1 for monitoring."
},
)
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model "
"and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. "
"Requires vllm_serve_lora serve module (auto-selected when this is True). "
"Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False."
},
)
# ---------------------------------------------------------------------------
# Extended data producer with re-roll injection
# ---------------------------------------------------------------------------
class RerollDataProducer(GRPODataProducer):
"""GRPODataProducer that injects re-roll candidates into prompt batches.
Reads from the trainer's ``_reroll_buffer`` (populated by
``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the
last N prompt groups with previously-failed prompts.
"""
def _pre_produce_hook(self, inputs: list, global_step: int) -> list:
trainer = self._trainer
reroll_buf = getattr(trainer, "_reroll_buffer", None)
reroll_lock = getattr(trainer, "_reroll_lock", None)
if reroll_buf is None or reroll_lock is None:
return inputs
max_steps = getattr(trainer.args, "max_steps", -1)
start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0)
max_groups = getattr(trainer.args, "reroll_max_groups", 1)
reroll_start_step = (
max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf")
)
if global_step < reroll_start_step:
return inputs
with reroll_lock:
n_to_take = min(max_groups, len(reroll_buf))
reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]
if reroll_prompts:
num_gen = self._num_generations
n_groups = len(inputs) // num_gen
for i, reroll_prompt in enumerate(reroll_prompts):
group_idx = n_groups - 1 - i
if group_idx < 0:
break
start = group_idx * num_gen
for j in range(num_gen):
inputs[start + j] = reroll_prompt
logger.info(
f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups "
f"with deferred re-roll candidates ({len(reroll_buf)} remaining)"
)
return inputs
# ---------------------------------------------------------------------------
# Persistent reward subprocess pool
# ---------------------------------------------------------------------------
def _persistent_reward_worker(conn):
"""Long-lived reward worker. Receives work items, returns results."""
while True:
try:
msg = conn.recv()
except EOFError:
break
if msg is None: # Shutdown signal
break
(
reward_funcs,
prompts,
completions,
completion_ids_list,
inputs,
reward_func_names,
) = msg
try:
keys = [
key
for key in inputs[0]
if key not in ["prompt", "completion", "completion_ids"]
]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
results = []
for reward_func, _reward_func_name in zip(
reward_funcs, reward_func_names, strict=True
):
output = reward_func(
prompts=prompts,
completions=completions,
completion_ids=completion_ids_list,
**reward_kwargs,
)
results.append(
[float(r) if r is not None else float("nan") for r in output]
)
conn.send(results)
except Exception:
conn.send(None)
# ---------------------------------------------------------------------------
# Extended trainer
# ---------------------------------------------------------------------------
class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
"""GRPOTrainer with experimental extensions.
Adds:
- Parallel reward subprocess workers (``reward_num_workers``)
- Replay buffer for high-signal group reuse (``replay_buffer_size``)
- Deferred re-roll of failed prompts (``reroll_start_fraction``)
- Zero-advantage micro-batch skipping
"""
def __init__(self, *args, **kwargs):
# These must be initialized before super().__init__() because
# _create_data_producer (called during super().__init__) needs them.
self._reroll_buffer: list = []
self._reroll_lock = threading.Lock()
# Temporarily suppress the base class's Liger + OPSM validation check,
# since this subclass supports it via a custom compute_liger_loss override.
grpo_args = kwargs.get("args")
if grpo_args is None:
for a in args:
if hasattr(a, "off_policy_mask_threshold"):
grpo_args = a
break
saved_threshold = None
if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False):
saved_threshold = grpo_args.off_policy_mask_threshold
grpo_args.off_policy_mask_threshold = None
super().__init__(*args, **kwargs)
if saved_threshold is not None:
grpo_args.off_policy_mask_threshold = saved_threshold
self.off_policy_mask_threshold = saved_threshold
# Replay buffer
if getattr(self.args, "replay_buffer_size", 0) > 0:
self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)
else:
self._replay_buffer = None
self._replay_recompute_logps = getattr(
self.args, "replay_recompute_logps", True
)
# Reward worker pool (lazy-initialized)
self._reward_workers = None
# -- Factory override: use RerollDataProducer ----------------------------
def _create_data_producer(self, args, train_dataset):
"""Override to use RerollDataProducer for re-roll prompt injection."""
from axolotl.core.trainers.grpo.async_trainer import (
AsyncDataProducer,
ProducerConfig,
)
producer_config = ProducerConfig(
mini_epochs=args.num_iterations,
max_rollouts=None,
eval_during_produce=False,
empty_cache_before_produce=True,
empty_cache_after_produce=True,
async_prefetch=args.async_prefetch,
prefetch_depth=args.prefetch_depth,
)
data_producer = RerollDataProducer(
config=producer_config,
prompt_dataset=train_dataset,
num_generations=self.num_generations,
generation_batch_size=args.generation_batch_size,
train_batch_size=args.per_device_train_batch_size,
steps_per_generation=args.steps_per_generation,
shuffle_dataset=self.shuffle_dataset,
seed=args.seed,
)
data_producer.set_trainer(self)
if args.async_prefetch:
data_producer = AsyncDataProducer(
data_producer,
background_produce_kwargs={"skip_policy_logps": True},
)
return data_producer
# -- Reward worker pool --------------------------------------------------
def _get_reward_workers(self):
"""Return a list of persistent reward worker subprocesses (lazy-initialized)."""
import multiprocessing as _mp
num_workers = getattr(self.args, "reward_num_workers", 1)
if num_workers < 1:
num_workers = 1
if self._reward_workers is not None:
alive = all(proc.is_alive() for conn, proc in self._reward_workers)
if alive and len(self._reward_workers) == num_workers:
return self._reward_workers
self._shutdown_reward_workers()
workers = []
for _ in range(num_workers):
parent_conn, child_conn = _mp.Pipe()
proc = _mp.Process(
target=_persistent_reward_worker, args=(child_conn,), daemon=True
)
proc.start()
child_conn.close()
workers.append((parent_conn, proc))
self._reward_workers = workers
return workers
def _shutdown_reward_workers(self):
"""Shut down all persistent reward workers."""
if self._reward_workers is None:
return
for conn, proc in self._reward_workers:
try:
conn.send(None)
proc.join(timeout=5)
except Exception:
pass
try:
conn.close()
except Exception:
pass
self._reward_workers = None
# -- Hook overrides ------------------------------------------------------
def _compute_rewards_for_batch(
self, inputs, prompts, completions, completion_ids_list
):
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
return self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list
)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Send reward work to subprocess workers (non-blocking).
Results are collected later by _collect_reward_workers, allowing GPU
logprob computation to overlap with CPU reward computation.
"""
reward_can_bg = all(
callable(rf)
and not isinstance(rf, nn.Module)
and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs
)
num_workers = getattr(self.args, "reward_num_workers", 1)
if not reward_can_bg or num_workers <= 1:
# Can't parallelize — store args for sync fallback in collect
self._reward_workers_used = None
self._pending_reward_args = (
inputs,
prompts,
completions,
completion_ids_list,
)
return
workers = self._get_reward_workers()
num_generations = self.num_generations
num_prompts = len(prompts)
num_groups = num_prompts // num_generations
# Shard by prompt groups across workers
groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
workers_used = []
for w_idx, (conn, _proc) in enumerate(workers):
g_start = w_idx * groups_per_worker
g_end = min((w_idx + 1) * groups_per_worker, num_groups)
if g_start >= num_groups:
break
s_start = g_start * num_generations
s_end = g_end * num_generations
conn.send(
(
self.reward_funcs,
prompts[s_start:s_end],
completions[s_start:s_end],
completion_ids_list[s_start:s_end],
inputs[s_start:s_end],
self.reward_func_names,
)
)
workers_used.append(conn)
self._reward_workers_used = workers_used
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
def _collect_reward_workers(
self, inputs, prompts, completions, completion_ids_list
):
"""Collect reward results from subprocess workers (blocks until done)."""
from accelerate.utils import gather
workers_used = getattr(self, "_reward_workers_used", None)
args = getattr(self, "_pending_reward_args", None)
self._reward_workers_used = None
self._pending_reward_args = None
if workers_used is None:
# Sync fallback — compute on main thread
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
device = self.accelerator.device
num_prompts = len(args[1]) if args else len(prompts)
# Collect results from workers
all_worker_results = []
any_failed = False
for conn in workers_used:
result = conn.recv()
if result is None:
any_failed = True
# Drain remaining workers to prevent stale results in pipes
for remaining_conn in workers_used:
if remaining_conn is not conn:
try:
remaining_conn.recv()
except Exception:
pass
break
all_worker_results.append(result)
if not any_failed:
rewards_per_func = torch.zeros(
num_prompts, len(self.reward_funcs), device=device
)
offset = 0
for worker_result in all_worker_results:
chunk_size = len(worker_result[0])
for i, result in enumerate(worker_result):
rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
result, dtype=torch.float32, device=device
)
offset += chunk_size
return gather(rewards_per_func)
# Fallback to main thread on failure
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
def _post_advantage_hook(
self,
data: dict,
rewards_per_func,
advantages,
inputs: list,
num_generations: int,
mode: str,
s_start: int | None = None,
s_end: int | None = None,
is_last_chunk: bool = True,
) -> None:
"""Replay buffer store/replace + re-roll buffering."""
from trl.models.utils import disable_gradient_checkpointing
# -- Replay buffer: store high-signal groups --
if self._replay_buffer is not None:
local_grouped = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = local_grouped.std(dim=1)
has_signal = (per_group_std > 0).any(dim=1)
offset = s_start or 0
if has_signal.any():
grouped_adv = advantages.view(-1, num_generations)
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
group_data = {}
for key in data:
val = data[key]
if (
isinstance(val, torch.Tensor)
and val.dim() > 0
and val.size(0) >= end
):
group_data[key] = val[start:end].clone()
self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups with high-signal replay buffer entries
# Only in non-streaming path (s_start is None) — streaming scores
# groups incrementally, so replacement + logprob recompute would be
# too expensive per chunk.
n_replaced = 0
if s_start is None:
no_signal = ~has_signal
replaced_ranges = []
if no_signal.any() and len(self._replay_buffer) > 0:
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
sampled = self._replay_buffer.sample(1)
if sampled is None:
break
sampled_group = sampled[0]
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
for key, val in sampled_group.items():
if key in data and isinstance(data[key], torch.Tensor):
src = val.to(data[key].device)
tgt_seq_len = (
data[key].size(1) if data[key].dim() > 1 else None
)
if start >= data[key].size(0) or end > data[key].size(
0
):
continue
if tgt_seq_len is not None:
if src.size(1) <= tgt_seq_len:
data[key][start:end] = 0
data[key][start:end, : src.size(1)] = src
else:
data[key][start:end] = src[:, :tgt_seq_len]
else:
data[key][start:end] = src
replaced_ranges.append((start, end))
n_replaced += 1
# Recompute old_per_token_logps for replayed groups
if (
n_replaced > 0
and self._replay_recompute_logps
and "old_per_token_logps" in data
):
with (
torch.no_grad(),
disable_gradient_checkpointing(
self.model, self.args.gradient_checkpointing_kwargs
),
):
for r_start, r_end in replaced_ranges:
r_ids = torch.cat(
[
data["prompt_ids"][r_start:r_end],
data["completion_ids"][r_start:r_end],
],
dim=1,
)
r_mask = torch.cat(
[
data["prompt_mask"][r_start:r_end],
data["completion_mask"][r_start:r_end],
],
dim=1,
)
r_logits_to_keep = data["completion_ids"].size(1)
r_fwd_kwargs = {}
for fk in (
"pixel_values",
"image_grid_thw",
"pixel_attention_mask",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
if fk in data:
r_fwd_kwargs[fk] = data[fk]
r_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
r_ids,
r_mask,
r_logits_to_keep,
r_end - r_start,
**r_fwd_kwargs,
)
data["old_per_token_logps"][r_start:r_end] = r_logps
if n_replaced > 0:
self._metrics[mode]["replay_buffer_replacements"].append(
float(n_replaced)
)
if is_last_chunk:
self._metrics[mode]["replay_buffer_size"].append(
float(len(self._replay_buffer))
)
# -- Re-roll buffer: store failed prompts --
if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0:
grouped_rewards = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = grouped_rewards.std(dim=1)
per_group_mean = grouped_rewards.mean(dim=1)
zero_signal = (per_group_std == 0).all(dim=1)
all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)
should_reroll = zero_signal & all_failed
_n_buffered = 0
with self._reroll_lock:
for group_idx in should_reroll.nonzero(as_tuple=True)[0]:
idx = group_idx.item() * num_generations
if idx >= len(inputs):
continue
prompt_input = inputs[idx]
self._reroll_buffer.append(prompt_input)
_n_buffered += 1
if _n_buffered > 0:
self._metrics[mode]["reroll_buffered"].append(float(_n_buffered))
if is_last_chunk:
self._metrics[mode]["reroll_buffer_size"].append(
float(len(self._reroll_buffer))
)
# -- Zero-advantage skipping + Liger OPSM ---------------------------------
def compute_liger_loss(self, unwrapped_model, inputs):
"""Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).
The base class Liger path doesn't support OPSM because the fused kernel
doesn't expose per-token logprobs needed for the KL computation. This
override computes them via chunked lm_head matmul (no grad, low memory)
and applies the OPSM to the loss mask before calling the kernel.
"""
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
return torch.tensor(
0.0, device=inputs["advantages"].device, requires_grad=True
)
if self.off_policy_mask_threshold is None:
return super().compute_liger_loss(unwrapped_model, inputs)
# OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
last_hidden_state = self._get_last_hidden_state(
unwrapped_model,
input_ids,
attention_mask,
logits_to_keep,
inputs.get("pixel_values"),
inputs.get("image_grid_thw"),
inputs.get("pixel_attention_mask"),
inputs.get("image_sizes"),
)
loss_mask = (
completion_mask
if "tool_mask" not in inputs
else completion_mask * inputs["tool_mask"]
)
# Compute per_token_logps via chunked lm_head matmul (no grad, low memory)
lm_weight = unwrapped_model.lm_head.weight
lm_bias = unwrapped_model.lm_head.bias
with torch.no_grad():
per_token_logps_chunks = []
for i in range(last_hidden_state.size(0)):
chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
if lm_bias is not None:
chunk_logits = chunk_logits + lm_bias
chunk_lps = (
chunk_logits.float()
.log_softmax(-1)
.gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
.squeeze(-1)
)
per_token_logps_chunks.append(chunk_lps)
del chunk_logits
per_token_logps = torch.cat(per_token_logps_chunks, dim=0)
advantages = inputs["advantages"]
if advantages.dim() == 1:
advantages_2d = advantages.unsqueeze(1)
else:
advantages_2d = advantages
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = inputs.get("old_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = per_token_logps
off_policy_mask = GRPOTrainer.get_off_policy_mask(
advantages=advantages_2d,
per_token_logps=per_token_logps,
sampling_per_token_logps=sampling_per_token_logps,
mask=loss_mask,
off_policy_threshold=self.off_policy_mask_threshold,
)
loss_mask = loss_mask * off_policy_mask
# Call the Liger fused kernel with OPSM-modified mask
loss, metrics = self.liger_grpo_loss(
_input=last_hidden_state,
lin_weight=unwrapped_model.lm_head.weight,
selected_token_ids=completion_ids,
attention_mask=loss_mask,
advantages=inputs["advantages"],
bias=unwrapped_model.lm_head.bias,
old_per_token_logps=inputs.get("old_per_token_logps"),
ref_per_token_logps=inputs.get("ref_per_token_logps"),
vllm_is_ratio=inputs.get("importance_sampling_ratio"),
)
mean_kl = metrics[0] if self.beta != 0.0 else None
clip_ratio = metrics[-1]
mode = "train" if self.model.training else "eval"
if self.beta != 0.0:
self._metrics[mode]["kl"].append(
self.accelerator.gather(mean_kl).mean().item()
)
self._metrics[mode]["clip_ratio"].append(
self.accelerator.gather(clip_ratio).mean().item()
)
normalizer = (
self.current_gradient_accumulation_steps if mode == "train" else 1.0
)
return loss / normalizer
def _compute_loss(self, model, inputs):
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
# Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
# With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
# so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
# with a single token to create a proper computation graph.
prompt_ids = inputs["prompt_ids"][:1, :1] # (1, 1)
attn = torch.ones_like(prompt_ids)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
out = model(input_ids=prompt_ids, attention_mask=attn)
return out.logits.sum() * 0
return super()._compute_loss(model, inputs)

View File

@@ -0,0 +1,44 @@
"""Simple replay buffer for storing and sampling high-signal rollout groups."""
import heapq
import torch
class ReplayBuffer:
"""Min-heap replay buffer that keeps the highest-scoring rollout groups.
Groups are scored by signal quality (advantage magnitude * reward variance).
When sampling, groups are drawn proportional to their scores.
"""
def __init__(self, max_size: int):
self.max_size = max_size
self._heap: list[tuple[float, int, dict]] = [] # min-heap of (score, id, data)
self._counter = 0 # unique tiebreaker for heap
def __len__(self):
return len(self._heap)
def add(self, score: float, data: dict):
"""Add a group to the buffer. If full, replaces lowest-scoring entry."""
if self.max_size <= 0:
return
self._counter += 1
if len(self._heap) < self.max_size:
heapq.heappush(self._heap, (score, self._counter, data))
elif score > self._heap[0][0]:
heapq.heapreplace(self._heap, (score, self._counter, data))
def sample(self, num_samples: int) -> list[dict] | None:
"""Sample groups weighted by their scores. Returns None if buffer is empty."""
if self.max_size <= 0 or not self._heap:
return None
scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)
scores = scores.clamp(min=1e-8) # avoid zero probabilities
probs = scores / scores.sum()
replacement = num_samples > len(self._heap)
indices = torch.multinomial(
probs, num_samples, replacement=replacement
).tolist()
return [self._heap[i][2] for i in indices]

View File

@@ -40,6 +40,7 @@ from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
@@ -66,6 +67,19 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
class AxolotlAsyncGRPOTrainer(
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
FastAsyncGRPOTrainer,
):
"""Extend AsyncGRPOTrainer with axolotl helpers"""
_tag_names = ["trl", "grpo", "async", "axolotl"]
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -4,6 +4,7 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .layer_offloading import LayerOffloadingMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -19,5 +19,4 @@ class CheckpointSaveMixin(Trainer):
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible.",
main_process_only=True,
)

View File

@@ -0,0 +1,304 @@
"""
Trainer mixin for layer-wise parameter offloading to CPU.
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
transfer stream), post-hook offloads layer N-1's frozen params.
Backward: same in reverse order.
"""
import contextlib
import torch
import torch.nn as nn
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
"""Recursively search the model for the decoder layer ModuleList.
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
where layers are at model.language_model.layers).
"""
# BFS to find the first ModuleList containing decoder layers
queue = [model]
while queue:
m = queue.pop(0)
for _name, child in m.named_children():
if isinstance(child, nn.ModuleList) and len(child) > 0:
first_type = type(child[0]).__name__
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
layer_types = list({type(layer).__name__ for layer in child})
return child, layer_types
else:
queue.append(child)
return None, []
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
"""Get all non-trainable parameters in a layer."""
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
class LayerOffloadManager:
"""Manages offloading frozen decoder layer params to CPU and streaming
them back during forward/backward with CUDA stream overlap.
Only frozen (requires_grad=False) parameters are offloaded.
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
"""
def __init__(
self,
model: nn.Module,
num_prefetch: int = 1,
):
self.model = model
self.num_prefetch = num_prefetch
self._hooks: list = []
self._device = None
# Find decoder layers
self.layers, layer_types = _find_decoder_layers(model)
if self.layers is None:
LOG.warning(
"LayerOffloadManager: no decoder layers found, offloading disabled"
)
self.enabled = False
return
self.enabled = True
self.n_layers = len(self.layers)
LOG.info(
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
)
# Determine GPU device
for p in model.parameters():
if p.device.type == "cuda":
self._device = p.device
break
if self._device is None:
LOG.warning("LayerOffloadManager: no CUDA parameters found")
self.enabled = False
return
# Transfer stream for async prefetch
self._transfer_stream = torch.cuda.Stream(device=self._device)
# Track which layers have their frozen params on GPU
self._on_gpu: set[int] = set(range(self.n_layers))
# Cache: frozen param references per layer (list of (name, param) tuples)
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
]
# CPU storage: pinned tensors for each layer's frozen params
# Populated on first offload
self._cpu_data: list[dict[str, torch.Tensor]] = [
{} for _ in range(self.n_layers)
]
# Offload all layers upfront
self._offload_all()
# Release cached memory blocks back to the driver
torch.cuda.empty_cache()
def _offload_all(self):
"""Move all frozen params in all decoder layers to CPU."""
mem_before = torch.cuda.memory_allocated(self._device)
for i in range(self.n_layers):
self._offload_layer(i)
mem_after = torch.cuda.memory_allocated(self._device)
freed = (mem_before - mem_after) / 1e6
LOG.info(
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
f"freed {freed:.0f} MB GPU memory"
)
def _offload_layer(self, idx: int):
"""Move frozen params of layer idx to CPU pinned memory."""
if idx not in self._on_gpu:
return
for name, param in self._frozen_params[idx]:
if param.device.type != "cuda":
continue
# Allocate pinned CPU tensor on first offload
if name not in self._cpu_data[idx]:
self._cpu_data[idx][name] = torch.empty_like(
param.data, device="cpu", pin_memory=True
)
cpu_buf = self._cpu_data[idx][name]
# Async copy GPU -> CPU (on transfer stream for overlap)
cpu_buf.copy_(param.data, non_blocking=True)
# Point parameter at a dummy CPU tensor to free GPU memory
param.data = cpu_buf
self._on_gpu.discard(idx)
def _load_layer(self, idx: int, stream=None):
"""Move frozen params of layer idx back to GPU."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
ctx = (
torch.cuda.stream(stream)
if stream is not None
else contextlib.nullcontext()
)
with ctx:
for _name, param in self._frozen_params[idx]:
if param.device.type == "cuda":
continue
gpu_data = param.data.to(self._device, non_blocking=True)
param.data = gpu_data
self._on_gpu.add(idx)
def _prefetch_layer(self, idx: int):
"""Async prefetch layer idx on the transfer stream."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
self._load_layer(idx, stream=self._transfer_stream)
def _wait_transfer(self):
"""Make default stream wait for any in-flight transfers."""
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
def setup_hooks(self):
"""Register forward and backward hooks on each decoder layer."""
if not self.enabled:
return
for idx in range(self.n_layers):
layer = self.layers[idx]
def make_pre_fwd(i):
def hook(module, args):
# Ensure this layer is on GPU
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch next layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i + offset)
return hook
def make_post_fwd(i):
def hook(module, args, output):
# Offload previous layer (no longer needed in forward)
if i > 0:
self._offload_layer(i - 1)
# Offload last layer after forward
if i == self.n_layers - 1:
self._offload_layer(i)
return hook
def make_pre_bwd(i):
def hook(module, grad_output):
# Load this layer for backward
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch previous layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i - offset)
return hook
def make_post_bwd(i):
def hook(module, grad_input, grad_output):
# Offload the layer above
if i < self.n_layers - 1:
self._offload_layer(i + 1)
# Offload first layer after backward
if i == 0:
self._offload_layer(i)
return hook
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
h2 = layer.register_forward_hook(make_post_fwd(idx))
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
self._hooks.extend([h1, h2, h3, h4])
def remove_hooks(self):
"""Remove all hooks and restore layers to GPU."""
for h in self._hooks:
h.remove()
self._hooks.clear()
if self.enabled:
for i in range(self.n_layers):
if i not in self._on_gpu:
self._load_layer(i)
def pre_step(self):
"""Called before each training step — ensure layers start offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for forward
self._prefetch_layer(0)
def post_step(self):
"""Called after each training step — ensure layers are offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for next step
self._prefetch_layer(0)
class _LayerOffloadContext:
"""Context manager wrapping pre_step / post_step around a training step."""
def __init__(self, manager: LayerOffloadManager):
self.manager = manager
def __enter__(self):
self.manager.pre_step()
return self
def __exit__(self, *args):
self.manager.post_step()
class LayerOffloadingMixin(Trainer):
"""
Trainer mixin class for layer-wise parameter offloading to CPU.
Offloads frozen decoder layer params to CPU at init, then streams them
on/off GPU one layer at a time during each training step.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if getattr(self.args, "layer_offloading", False):
LOG.info("Layer parameter offloading enabled")
self._layer_offload_manager = LayerOffloadManager(
model=self.model,
num_prefetch=1,
)
self._layer_offload_manager.setup_hooks()
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
else:
self._layer_offload_manager = None
self._layer_offload_ctx = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self._layer_offload_ctx:
return super().training_step(*args, **kwargs)

View File

@@ -235,6 +235,13 @@ class AxolotlTrainingMixins:
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
layer_offloading: bool | None = field(
default=None,
metadata={
"help": "Offload model layer parameters to CPU during forward, prefetch back during backward."
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
```
## Usage
@@ -73,8 +73,10 @@ plugins:
- ministral3
- mistral
- mistral3
- mistral4
- mixtral
- mllama
- nemotron_h
- olmo
- olmo2
- olmo3

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
)

View File

@@ -0,0 +1,120 @@
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""
import logging
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
LOG = logging.getLogger(__name__)
# Give up looking for autotune data after this many training steps.
_MAX_POLL_STEP = 5
def _get_gpu_info() -> dict:
"""Return basic GPU identification for the current device."""
if not torch.cuda.is_available():
return {}
try:
idx = torch.cuda.current_device()
props = torch.cuda.get_device_properties(idx)
return {
"gpu_name": props.name,
"gpu_compute_capability": f"{props.major}.{props.minor}",
"gpu_memory_bytes": props.total_memory,
}
except Exception: # pylint: disable=broad-exception-caught
return {}
def _get_smem_capacity() -> dict:
"""Return shared memory capacity from the runtime lora_ops module."""
try:
from axolotl.integrations.kernels.autotune_collector import (
_find_lora_ops_module,
)
lora_ops = _find_lora_ops_module()
if lora_ops is None:
return {}
fn = getattr(lora_ops, "_get_smem_capacity", None)
if fn is None:
return {}
return {"smem_capacity_bytes": fn()}
except Exception: # pylint: disable=broad-exception-caught
return {}
class AutotuneReportCallback(TrainerCallback):
"""Reports Triton kernel autotune selections via telemetry.
Fires **once** after the first training step completes (step 1), at
which point the forward and backward passes have both run and the
autotuned kernels have populated their caches. If for some reason
the caches are still empty (e.g. the kernel was never invoked), the
callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
then stops polling.
After reporting (or giving up) every subsequent ``on_step_end``
call short-circuits on the ``_reported`` flag — zero hot-path cost.
"""
def __init__(self):
self._reported = False
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._reported:
return
# Lazy import — Triton / scattermoe kernels may not be installed.
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
configs = collect_autotune_configs()
if not configs:
if state.global_step >= _MAX_POLL_STEP:
LOG.debug(
"No autotune data found after %d steps; giving up.",
state.global_step,
)
self._reported = True
return
self._reported = True
from axolotl.telemetry.manager import TelemetryManager
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return
properties = {
"kernel_count": len(configs),
"kernels": configs,
}
properties.update(_get_gpu_info())
properties.update(_get_smem_capacity())
telemetry_manager.send_event(
event_type="scattermoe-autotune",
properties=properties,
)
LOG.info(
"Reported %d scattermoe kernel autotune config(s) to telemetry.",
len(configs),
)

View File

@@ -0,0 +1,114 @@
"""Collect Triton autotune results from scattermoe-lora kernels.
This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
decorated kernel objects and returns structured dicts describing the selected
configurations. It has **no** telemetry dependency — callers decide what to
do with the data.
"""
import logging
import sys
from types import ModuleType
from typing import Any
LOG = logging.getLogger(__name__)
# (human-readable name, attribute on the lora_ops module)
_KERNEL_REGISTRY: list[tuple[str, str]] = [
("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
("group_bwd_lora", "_group_bwd_lora"),
("group_bwd_lora_fused", "_group_bwd_lora_fused"),
]
# The autotune key declared on every kernel: key=["M", "N", "K"]
_KEY_NAMES: list[str] = ["M", "N", "K"]
def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
"""Turn the autotune cache key tuple into a labelled dict.
Triton builds the cache key from the values of the declared ``key``
args (``M``, ``N``, ``K``) followed by dtype signature elements.
We label the first three and store the rest under ``_extra``.
"""
result: dict[str, Any] = {}
for i, name in enumerate(_KEY_NAMES):
if i < len(key_tuple):
result[name] = key_tuple[i]
if len(key_tuple) > len(_KEY_NAMES):
result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
return result
def _find_lora_ops_module() -> ModuleType | None:
"""Locate the *runtime* ``lora_ops`` module in ``sys.modules``.
The HF ``kernels`` package loads ``scattermoe_lora`` via
``import_from_path`` which registers it in ``sys.modules`` under a
hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
import (``from axolotl.integrations.kernels...``) would create a
*separate* module instance whose kernel objects have empty
``.cache`` dicts because autotuning ran on the runtime copy.
We search ``sys.modules`` for any module whose name contains
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
attribute — that is the runtime copy with populated caches.
"""
for name, module in list(sys.modules.items()):
if (
module is not None
and "lora_ops" in name
and hasattr(module, "_scatter2scatter_lora")
):
return module
return None
def collect_autotune_configs() -> list[dict[str, Any]]:
"""Read autotune caches from the four scattermoe-lora kernels.
Returns a (possibly empty) list of dicts, each containing:
* ``kernel`` human-readable kernel name
* ``key`` dict with the ``M``/``N``/``K`` problem dimensions
* ``config`` dict with the selected tile sizes, ``num_warps``,
and ``num_stages``
Returns ``[]`` if the kernel module cannot be found or if no
autotune cache entries exist yet.
"""
lora_ops = _find_lora_ops_module()
if lora_ops is None:
LOG.debug(
"lora_ops module not found in sys.modules; skipping autotune collection"
)
return []
results: list[dict[str, Any]] = []
for friendly_name, attr_name in _KERNEL_REGISTRY:
kernel_fn = getattr(lora_ops, attr_name, None)
if kernel_fn is None:
continue
cache = getattr(kernel_fn, "cache", None)
if not cache:
continue
for key_tuple, config in cache.items():
config_dict = dict(config.kwargs)
config_dict["num_warps"] = config.num_warps
config_dict["num_stages"] = config.num_stages
if getattr(config, "num_ctas", None) is not None:
config_dict["num_ctas"] = config.num_ctas
results.append(
{
"kernel": friendly_name,
"key": _parse_key_tuple(key_tuple),
"config": config_dict,
}
)
return results

View File

@@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock",
"qwen3_next": "Qwen3NextSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
@@ -25,6 +26,8 @@ SPARSE_MOE_BLOCK = {
"olmoe": "OlmoeSparseMoeBlock",
"mixtral": "MixtralSparseMoeBlock",
"minimax": "MiniMaxSparseMoeBlock",
# softmax -> topk routing (with group-based expert selection)
"mistral4": "Mistral4MoE",
# sigmoid -> topk routing (with group-based expert selection)
"glm_moe_dsa": "GlmMoeDsaMoE",
"deepseek_v3": "DeepseekV3MoE",
@@ -56,7 +59,16 @@ def resolve_moe_block_classes(model_type: str):
cls_names = entry if isinstance(entry, list) else [entry]
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = importlib.import_module(module_path)
try:
module = importlib.import_module(module_path)
except ModuleNotFoundError:
# Text sub-model types (e.g. qwen3_5_moe_text) share the parent module
if model_type.endswith("_text"):
parent_type = model_type.removesuffix("_text")
module_path = f"transformers.models.{parent_type}.modeling_{parent_type}"
module = importlib.import_module(module_path)
else:
raise
classes = []
for cls_name in cls_names:

View File

@@ -195,6 +195,36 @@ def _estimate_smem_usage(
_SMEM_SLACK = 10_000
def _estimate_register_pressure(
num_warps: int,
*tile_sizes: tuple[int, int],
) -> float:
"""Rough estimate of per-thread register footprint from live tile sizes.
This is a heuristic, NOT an accurate register count. Triton uses tensor
core MMA fragments that pack multiple elements per register, and can spill
to local memory when the hardware limit (255 regs/thread) is exceeded.
The estimate is used to prune only truly extreme configs that would cause
excessive spilling or compilation failures. The threshold is set high
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
(est ~520) work fine in practice via spilling.
Returns estimated registers per thread.
"""
# Each thread in a warp holds ~1/32 of the tile elements
tile_regs = sum(r * c for r, c in tile_sizes) / 32
scalar_overhead = 40
return tile_regs + scalar_overhead
# Soft limit for register pressure pruning. Only prune configs with extreme
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
_MAX_REGS_SOFT_LIMIT = 1024
# =============================================================================
# Forward Kernel: scatter2scatter with fused LoRA
# =============================================================================
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
) # [BLOCK_N, BLOCK_R]
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
# Both operands must match; cast to float32 (accumulator type) for precision.
b_f32 = b.to(tl.float32)
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
b_inp = b.to(INPUT_DTYPE)
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
acc += scaling * lora_out
return acc
@@ -327,28 +356,29 @@ def _compute_expert_block_lora(
def _scatter2scatter_lora_configs():
"""Generate forward kernel autotune configs.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space includes BLOCK_M to allow trading token-tile size for
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
(4× fewer inner-loop iterations).
Search space:
BLOCK_N: {32, 64, 128, 256}
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
scatter2scatter pattern).
"""
configs = []
for block_n, block_k, warps, stages in product(
[32, 64, 128, 256], # BLOCK_N
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_N": block_n, "BLOCK_K": block_k},
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
num_stages=stages,
num_warps=warps,
)
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
def _prune_fwd_configs(configs, named_args, **kwargs):
"""Prune forward configs based on SMEM capacity.
"""Prune forward configs based on SMEM capacity and register pressure.
The forward kernel inner loop loads three tiles per pipeline stage:
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_n = config.kwargs["BLOCK_N"]
block_k = config.kwargs["BLOCK_K"]
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_r * block_k * 2
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
smem_lora_epilogue = block_n * block_r * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_n), # acc
(block_m, block_r), # xa_acc
(block_m, block_k), # x tile
(block_k, block_n), # w tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
),
)
]
@triton.autotune(
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def _scatter2scatter_lora_split(
X: torch.Tensor,
W: torch.Tensor,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
k: int,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
b: Optional[torch.Tensor] = None,
x_grouped: bool = False,
y_grouped: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
because the base kernel runs at full speed without LoRA SMEM overhead,
and the LoRA matmuls (R=16) are tiny separate passes.
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
"""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
scatter2scatter,
)
E = W.size(0)
R = lora_A.size(0) // E
K = W.size(1)
N = W.size(2)
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
output = scatter2scatter(
X=X,
W=W,
b=b,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=y_grouped,
out=out,
)
# 2. XA = X @ A^T (tiny: output is [M*k, R])
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
XA = scatter2scatter(
X=X,
W=W_A,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=True,
)
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
# Reshape B: [N, R*E] → [E, R, N]
W_B = lora_B.T.reshape(E, R, N).contiguous()
Y_lora = scatter2scatter(
X=XA,
W=W_B,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
x_grouped=True,
y_grouped=y_grouped,
)
# 4. Y = Y_base + scaling * Y_lora
output.add_(Y_lora, alpha=scaling)
return output
# Threshold for switching from fused to split LoRA forward.
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
# loads dominate in the fused kernel's inner loop).
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
def scatter2scatter_lora(
X: torch.Tensor,
W: torch.Tensor,
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Automatically selects between:
- Fused kernel: single Triton kernel with LoRA in the inner loop.
Best for many small experts (E>=64, small K*N).
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
Best for few large experts (E<=32, large K*N like Mixtral).
Args:
X: Input [M, K] or [M*k, K] if x_grouped
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
Returns:
Y: Output [M*k, N]
"""
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
E = W.size(0)
K = W.size(1)
N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
return _scatter2scatter_lora_split(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
lora_A,
lora_B,
scaling,
b,
x_grouped,
y_grouped,
out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
R = lora_A.size(0) // E
# Pad R to power of 2 for Triton tile size
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
b_ptr,
stride_be,
stride_bn,
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
lora_A,
lora_A.stride(0),
lora_A.stride(1),
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
lora_B,
lora_B.stride(0),
lora_B.stride(1),
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
K=K,
N=N,
E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
allow_tf32=ALLOW_TF32,
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
+ (A_expert_offset + R_block)[:, None] * stride_ar
+ K_block[None, :] * stride_ak
)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
# Cast to float32 for precision
a_f32 = a_e.to(tl.float32)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
INPUT_DTYPE
)
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
acc += scaling * lora_dx
return acc
@@ -779,25 +939,26 @@ def _scatter2scatter_lora_dX_configs():
The inner loop is over N (not K as in forward). The output dimension is K.
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
BLOCK_M is now autotunable (was fixed at 128).
Search space:
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128} (output tile)
BLOCK_N: {32, 64} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_k, block_n, warps, stages in product(
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K (output dimension)
[32, 64], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_K": block_k, "BLOCK_N": block_n},
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
num_stages=stages,
num_warps=warps,
)
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
def _prune_dX_configs(configs, named_args, **kwargs):
"""Prune backward dX configs based on SMEM capacity.
"""Prune backward dX configs based on SMEM capacity and register pressure.
The dX kernel inner loop loads three tiles per pipeline stage:
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_k = config.kwargs["BLOCK_K"]
block_n = config.kwargs["BLOCK_N"]
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_n * block_r * 2
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
smem_lora_epilogue = block_r * block_k * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_k), # acc
(block_m, block_r), # dy_b_acc
(block_m, block_n), # dy tile
(block_n, block_k), # wt tile
(block_n, block_r), # b tile
(block_r, block_k), # a tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
N=N,
E=E,
ACTUAL_R=R,
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
@@ -1091,9 +1278,9 @@ def _group_bwd_lora_configs():
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space:
BLOCK_M: {32, 64, 128, 256} (token-loop tile)
BLOCK_K: {32, 64, 128, 256}
BLOCK_N: {32, 64, 128, 256}
BLOCK_M: {32, 64, 128} (token-loop tile)
BLOCK_K: {32, 64, 128}
BLOCK_N: {32, 64}
num_warps: {4, 8}
num_stages: {3, 4, 5}
@@ -1102,9 +1289,9 @@ def _group_bwd_lora_configs():
"""
configs = []
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128, 256], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_M
[32, 64, 128], # BLOCK_K
[32, 64], # BLOCK_N
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
"""Prune backward configs based on SMEM capacity.
"""Prune backward configs based on SMEM capacity and register pressure.
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
smem_lora = (block_r * block_k + block_n * block_r) * 2
smem = smem_base + smem_lora
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_k), # dA_acc
(block_n, block_r), # dB_acc
(block_m, block_k), # x tile
(block_m, block_n), # dy tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile
(block_m, block_r), # xa intermediate
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
)
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity and register pressure."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weights held in registers: [INNER, R] or [R, DIM]
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
# Register pressure check
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_dim), # acc
(block_m, BLOCK_INNER), # input tile
(block_m, block_dim), # other tile
(block_r, BLOCK_INNER), # lora weight
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
{
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
}
)
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr,
stride_dym,
stride_dyn,
X_ptr,
stride_xm,
stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr,
stride_lw0,
stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr,
stride_out0,
stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M,
K: tl.constexpr,
N: tl.constexpr,
ACTUAL_R: tl.constexpr,
BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
else:
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr
+ inn_off[:, None] * stride_lw0
+ (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr
+ M_idx[:, None] * stride_xm
+ inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr
+ (lora_offset + R_block)[None, :] * stride_lw0
+ inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
)
tl.store(
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
else:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
def group_bwd_lora(
DY: torch.Tensor,
X: torch.Tensor,
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
"""
Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert)
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
K = X.size(1)
N = DY.size(1)
# Zero-init for atomic accumulation
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
# No zero-init needed: the split kernels write zeros for experts with
# zero routed tokens directly in the kernel (else branch).
dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R)
def grid(META):
return (
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
def grid_dA(META):
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora[grid](
_group_bwd_lora_split[grid_dA](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
COMPUTE_DA=True,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY,
DY.stride(0),
DY.stride(1),
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)

View File

@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
return base_experts, gup_lora, down_lora
# =============================================================================
# Routing helpers
# =============================================================================
def _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return routing_weights, selected_experts, top_k, num_experts
def _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).
Supports:
- ``e_score_correction_bias`` on gate or moe_block
- Group-based expert selection when ``n_group > 1``
- ``routed_scaling_factor`` applied to final weights
- Final weights gathered from original sigmoid probs (not bias-corrected)
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states.float(), gate_weight.float())
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = router_logits.sigmoid() # [T, E]
top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is not None:
scores_for_choice = router_probs + e_score_correction_bias
else:
scores_for_choice = router_probs
# Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1)
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, num_experts // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [T, n_group]
topk_group = getattr(moe_block, "topk_group", n_group)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(-1, n_group, num_experts // n_group)
.reshape(-1, num_experts)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_indices, top_k, num_experts
def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
"""Dispatch to the correct routing strategy based on block attributes.
Detects sigmoid routing by the presence of ``e_score_correction_bias``
on either the gate or the moe_block.
"""
has_sigmoid = (
getattr(base_gate, "e_score_correction_bias", None) is not None
or getattr(moe_block, "e_score_correction_bias", None) is not None
)
if has_sigmoid:
return _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
return _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
# =============================================================================
# Shared expert helpers
# =============================================================================
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
peft wraps individual linear layers inside the shared expert with
standard LoRA — calling forward() handles this transparently.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is None:
return None
shared_expert_output = shared_expert(hidden_states_flat)
# Optional sigmoid gate (Qwen2MoE pattern).
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
if shared_expert_gate is not None:
shared_expert_output = (
F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
)
return shared_expert_output
# =============================================================================
# Layer classes
# =============================================================================
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):
class HFScatterMoEGatedMLP(nn.Module):
"""
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
ScatterMoE-accelerated forward pass for HF MoEs.
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
method replaces the original ``OlmoeSparseMoeBlock.forward``.
method replaces the original SparseMoeBlock.forward.
Supports both full-parameter training and LoRA fine-tuning:
Supports:
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
- self.shared_expert(s): Optional shared expert
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE)
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
# ====================================================================
# peft wraps individual linear layers inside shared_expert with
# standard LoRA — calling forward() handles this transparently.
if hasattr(self, "shared_expert") and self.shared_expert is not None:
shared_expert_output = self.shared_expert(hidden_states_flat)
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate_output = F.sigmoid(
self.shared_expert_gate(hidden_states_flat)
)
shared_expert_output = shared_expert_output * shared_expert_gate_output
else:
shared_expert_output = None
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
router_logits = F.linear(hidden_states_flat, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states_flat, gate_lora_delta
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if base_gate.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights, selected_experts, top_k, num_experts = _route(
self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
@@ -356,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
# When experts are BnB-quantized (quantize_moe_experts), dequantize
# only the active experts instead of all E. This saves ~97% memory
# for the transient dequant buffer when few experts are active.
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
get_active_experts,
remap_expert_indices,
selective_expert_weights,
selective_lora_weights,
)
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs,
expert_offsets,
active_experts,
num_experts,
)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts,
"gate_up_proj",
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A,
gup_B,
active_experts,
num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
# Use remapped indices for ScatterMoE kernels
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
@@ -383,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
grouped_in=False,
grouped_out=True,
)
@@ -396,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if use_selective:
down_W = selective_expert_weights(
experts,
"down_proj",
active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A,
down_B,
active_experts,
num_experts,
)
down_lora = (down_A, down_B, down_scaling)
sei_down = remapped_expert_idxs
eo_down = compact_offsets
else:
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
sei_down = sorted_expert_idxs
eo_down = expert_offsets
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
@@ -404,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
@@ -421,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
grouped_in=True,
grouped_out=False,
gates=routing_weights,

View File

@@ -0,0 +1,282 @@
"""
Selective Expert Dequantization
===============================
Instead of dequantizing all E expert weight matrices at once (which creates
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
are actually routed to by the current batch's top-k selection.
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
- Savings: ~97% memory reduction per layer
This module provides format-agnostic selective weight extraction:
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
- bf16/fp32: direct indexing (no dequant needed)
- FP8: slice + cast
The ScatterMoE kernel itself doesn't change — we remap expert indices
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""
import torch
import torch.nn as nn
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
"""Get sorted unique expert indices from the routing output.
Args:
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
E: Total number of experts
Returns:
active: Sorted unique expert indices [num_active]
"""
return torch.unique(sorted_expert_idxs)
def remap_expert_indices(
sorted_expert_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Remap global expert indices to compact indices.
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
sort order. Also compacts expert_offsets to only active experts.
Args:
sorted_expert_idxs: [T*k] expert ids in sorted order
expert_offsets: [E] cumulative token counts (original)
active_experts: [num_active] sorted unique expert ids
E: Total number of experts
Returns:
remapped_idxs: [T*k] expert ids in [0..num_active-1]
compact_offsets: [num_active] cumulative token counts
"""
# Build remap table: global_id -> compact_id
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
remap[active_experts] = torch.arange(
len(active_experts), device=sorted_expert_idxs.device
)
remapped_idxs = remap[sorted_expert_idxs]
# Compact the expert_offsets: only keep active experts' cumulative counts
compact_offsets = expert_offsets[active_experts]
return remapped_idxs, compact_offsets
def _selective_dequant_bnb4(
raw_param: torch.Tensor,
quant_state,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
) -> torch.Tensor:
"""Dequantize only selected experts from BnB 4-bit packed data.
The raw parameter is a flattened 4-bit packed tensor. Each expert's
data is contiguous (stored in expert-major order), so we can gather
the packed data and absmax blocks for active experts, then dequantize
as one contiguous block.
Args:
raw_param: Flattened uint8 tensor of packed 4-bit weights
quant_state: BnB QuantState with absmax, blocksize, code, etc.
active_experts: [num_active] expert indices to dequantize
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
Returns:
Dequantized weights [num_active, dim1, dim2] in original dtype
"""
import bitsandbytes.functional as F # noqa: N812
from bitsandbytes.functional import QuantState
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
blocks_per_expert = expert_numel // quant_state.blocksize
num_active = len(active_experts)
if blocks_per_expert == 0:
# Expert is smaller than one quantization block — blocks span across
# expert boundaries, so per-expert slicing isn't possible.
# Fallback: full dequantize + index.
full = F.dequantize_4bit(raw_param, quant_state)
E_total = full.numel() // expert_numel
return full.reshape(E_total, *expert_shape)[active_experts]
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
selective_dequant_nf4_triton,
)
# Handle nested (double) quantization: dequantize absmax first
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
if quant_state.nested:
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
else:
absmax = quant_state.absmax
return selective_dequant_nf4_triton(
packed_data=raw_param,
absmax=absmax,
active_experts=active_experts,
expert_shape=expert_shape,
blocksize=quant_state.blocksize,
dtype=quant_state.dtype,
codebook=quant_state.code,
)
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
raw_flat = raw_param.reshape(-1)
offsets_qt = (
active_experts.long()[:, None] * packed_per_expert
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
qt_gathered = raw_flat[offsets_qt]
offsets_abs = (
active_experts.long()[:, None] * blocks_per_expert
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
if quant_state.nested:
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
full_absmax += quant_state.offset
if full_absmax.dtype != torch.float32:
full_absmax = full_absmax.float()
absmax_gathered = full_absmax[offsets_abs]
else:
absmax_gathered = quant_state.absmax[offsets_abs]
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
gathered_qs = QuantState(
absmax=absmax_gathered,
shape=torch.Size([num_active * expert_numel]),
blocksize=quant_state.blocksize,
quant_type=quant_state.quant_type,
code=quant_state.code,
dtype=quant_state.dtype,
)
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
return deq.reshape(num_active, *expert_shape)
def _selective_index_dense(
param: torch.Tensor,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Select experts from a dense (bf16/fp32) weight tensor.
Simple indexing — no dequantization needed.
"""
return param[active_experts]
def selective_expert_weights(
experts_module: nn.Module,
param_name: str,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Extract and dequantize only the active experts' weights.
Format-agnostic: dispatches based on whether the parameter is
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
Args:
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
param_name: "gate_up_proj" or "down_proj"
active_experts: [num_active] sorted unique expert indices
Returns:
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
"""
# Check if the parameter is BnB-quantized via parametrize
if (
hasattr(experts_module, "parametrizations")
and param_name in experts_module.parametrizations
):
param_list = experts_module.parametrizations[param_name]
parametrization = param_list[0]
# BnB 4-bit parametrization
if hasattr(parametrization, "quant_state"):
# The raw quantized data is on the ParametrizationList, not the
# individual Bnb4bitParametrization module
raw_param = param_list.original
qs = parametrization.quant_state
# qs.shape is the original tensor shape before flattening.
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
orig_shape = qs.shape
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
expert_shape = (orig_shape[1], orig_shape[2])
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
# Flattened — need to infer from module attributes
E_total = getattr(experts_module, "num_experts", None)
if E_total is None:
E_total = int(active_experts.max().item()) + 1
expert_numel = orig_shape[0] // E_total
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
experts_module, "intermediate_dim", None
)
if d2 and expert_numel % d2 == 0:
expert_shape = (expert_numel // d2, d2)
else:
full = getattr(experts_module, param_name)
return full[active_experts]
else:
full = getattr(experts_module, param_name)
return full[active_experts]
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
# Dense parameter (bf16/fp32) — direct indexing
param = getattr(experts_module, param_name)
if param.dim() == 3:
return param[active_experts]
# Fallback: full access
return param
def selective_lora_weights(
lora_A: torch.Tensor,
lora_B: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select LoRA A and B weights for only the active experts.
LoRA layout (scattermoe format):
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
Returns compact:
A: [r*num_active, K]
B: [N, r*num_active]
"""
R = lora_A.size(0) // E
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
row_idx = (
active_experts.long()[:, None] * R
+ torch.arange(R, device=lora_A.device)[None, :]
).reshape(-1)
compact_A = lora_A[row_idx] # [r*num_active, K]
compact_B = lora_B[:, row_idx] # [N, r*num_active]
return compact_A, compact_B

View File

@@ -0,0 +1,179 @@
"""
Triton kernel for fused selective expert gather + NF4 dequantization.
Instead of:
1. Gather packed uint8 data for active experts (memory copy)
2. Gather absmax for active experts (memory copy)
3. Call BnB dequantize_4bit CUDA kernel
This kernel does all three in one pass:
- Reads packed NF4 bytes from expert-strided positions
- Looks up the NF4 codebook
- Multiplies by the per-block absmax
- Writes bf16 output directly
This eliminates the intermediate gather buffer entirely.
"""
import torch
import triton
import triton.language as tl
# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
@triton.jit
def _selective_dequant_nf4_kernel(
# Input: packed NF4 data (flattened, expert-major order)
packed_ptr,
# Input: absmax values (flattened, expert-major order)
absmax_ptr,
# Input: active expert indices
active_experts_ptr,
# Input: NF4 codebook (16 float values)
codebook_ptr,
# Output: dequantized bf16 weights [num_active, expert_numel]
out_ptr,
stride_out_e, # stride for expert dim in output
# Dimensions
num_active,
packed_per_expert, # expert_numel // 2
blocks_per_expert, # expert_numel // blocksize
blocksize: tl.constexpr,
# Tile size
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
):
"""
Each program processes BLOCK_SIZE elements from one expert.
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
For each output element:
1. Compute which byte in packed data contains this element
2. Extract the 4-bit nibble (high or low)
3. Look up in NF4 codebook
4. Scale by absmax for this block
"""
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
block_id = tl.program_id(1) # which element block
# Load the global expert index
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = elem_offset < expert_numel
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
byte_idx = elem_offset // 2
is_high = (elem_offset % 2) == 1
# Read packed bytes from the global expert's region
packed_global_offset = expert_global * packed_per_expert + byte_idx
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
tl.int32
)
# Extract 4-bit nibble
# BnB packing: high nibble = even element, low nibble = odd element
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
# NF4 codebook lookup
# Load all 16 codebook values (small, fits in registers)
# Use gather from codebook pointer
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
# Load absmax for this element's quantization block
block_idx = elem_offset // blocksize
absmax_global_offset = expert_global * blocks_per_expert + block_idx
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
# Dequantize: value = codebook[nibble] * absmax
result = code_val * absmax_val
# Store to output
out_offset = expert_local_idx * stride_out_e + elem_offset
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
def selective_dequant_nf4_triton(
packed_data: torch.Tensor,
absmax: torch.Tensor,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
blocksize: int,
dtype: torch.dtype = torch.bfloat16,
codebook: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fused selective gather + NF4 dequantization via Triton kernel.
Args:
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
absmax: Per-block scaling factors [total_blocks]
active_experts: Sorted indices of experts to dequantize [num_active]
expert_shape: (dim1, dim2) per expert
blocksize: Quantization block size
dtype: Output dtype (default bf16)
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
Returns:
Dequantized weights [num_active, dim1, dim2]
"""
num_active = active_experts.shape[0]
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2
blocks_per_expert = expert_numel // blocksize
# Prepare codebook on device
if codebook is None:
codebook = torch.tensor(
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
)
else:
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
# Flatten inputs
packed_flat = packed_data.reshape(-1)
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
# Output buffer
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
_selective_dequant_nf4_kernel[grid](
packed_flat,
absmax_flat,
active_experts,
codebook,
out,
out.stride(0),
num_active=num_active,
packed_per_expert=packed_per_expert,
blocks_per_expert=blocks_per_expert,
blocksize=blocksize,
BLOCK_SIZE=BLOCK_SIZE,
)
return out.reshape(num_active, *expert_shape)

View File

@@ -61,9 +61,20 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
self._kernelize_model(moe_model_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
@@ -75,11 +86,9 @@ class KernelsPlugin(BasePlugin):
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
cfg.model_config_type,
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)
@@ -110,6 +119,16 @@ class KernelsPlugin(BasePlugin):
}
)
def add_callbacks_pre_trainer(self, cfg, model):
callbacks = []
if cfg.use_scattermoe:
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
callbacks.append(AutotuneReportCallback())
return callbacks
def _kernelize_model(self, model_type: str):
from kernels import replace_kernel_forward_from_hub

View File

@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
"minimax",
):
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("mistral4",):
return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in (
"glm_moe_dsa",
"deepseek_v3",
@@ -126,6 +129,62 @@ def softmax_topk_routing(
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def softmax_group_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
T, H = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
scores_for_choice = router_probs
# Group selection: pick top groups, mask the rest
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, E // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices)
# Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
# Flatten for moe_general_routing_inputs
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def sigmoid_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@@ -25,7 +25,7 @@ def get_lora_parameters(
) -> tuple[
torch.Tensor,
torch.Tensor | None,
QuantState | None,
QuantState | torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
@@ -48,9 +48,13 @@ def get_lora_parameters(
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
if quant_state is None and W.dtype == torch.float8_e4m3fn:
quant_state = getattr(base_layer, "weight_scale_inv", None)
return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
if quant_state is None and W.dtype == torch.float8_e4m3fn:
quant_state = getattr(base_layer, "weight_scale_inv", None)
active_adapter = (
proj.active_adapters[0]
@@ -81,7 +85,7 @@ def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor | None,
W_quant: QuantState | None,
W_quant: QuantState | torch.Tensor | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
s: float | None,
@@ -636,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
# This is 65x fewer FLOPs than materializing B@A into [out, in]
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
# K path
k_weight_t = dequantize(k_weight, k_quant)
@@ -644,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
# V path
v_weight_t = dequantize(v_weight, v_quant)
@@ -652,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
# Transpose gradients if needed
if d_A_q is not None:
@@ -815,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
del W
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
dX.addmm_(torch.mm(dY, B), A, alpha=s)
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None

View File

@@ -1,4 +1,4 @@
"""Dequantization utilities for `bitsandbytes` integration."""
"""Dequantization utilities for `bitsandbytes` and FP8 integration."""
import ctypes
@@ -15,9 +15,50 @@ CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
def dequantize_fp8(
W: torch.Tensor,
scale_inv: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
Args:
W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn.
scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)]
or per-tensor scalar.
dtype: Output dtype (default bf16).
Returns:
Dequantized tensor in the specified dtype.
"""
W_float = W.to(dtype)
if scale_inv.numel() == 1:
return W_float * scale_inv.to(dtype)
if scale_inv.dim() == 2 and W.dim() == 2:
sr, sc = scale_inv.shape
br = W.shape[0] // sr
bc = W.shape[1] // sc
# If dimensions are exactly divisible, use fast reshape path
if sr * br == W.shape[0] and sc * bc == W.shape[1]:
return (
W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype)
).reshape(W.shape)
# Tail-block handling: compute actual block size (ceil division),
# tile scale_inv to cover full shape, then crop to W's dimensions
br_ceil = -(-W.shape[0] // sr) # ceil(rows / scale_rows) = block_size
bc_ceil = -(-W.shape[1] // sc)
scale_expanded = (
scale_inv.to(dtype)
.repeat_interleave(br_ceil, dim=0)
.repeat_interleave(bc_ceil, dim=1)
)[: W.shape[0], : W.shape[1]]
return W_float * scale_expanded
return W_float * scale_inv.to(dtype)
def dequantize(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
quant_state: QuantState | list | torch.Tensor | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
@@ -49,6 +90,15 @@ def dequantize(
if quant_state is None:
return W
# FP8 path: quant_state is actually scale_inv tensor
if W.dtype == torch.float8_e4m3fn:
scale_inv = quant_state
# Caller may pass W.t() (non-contiguous) — dequantize in original
# layout then transpose back so the result shape matches the input.
if not W.is_contiguous() and W.dim() == 2:
return dequantize_fp8(W.t(), scale_inv).t()
return dequantize_fp8(W, scale_inv)
# Get the target device from input tensor W
target_device = W.device

View File

@@ -160,6 +160,18 @@ def load_lora(
else:
model = get_peft_model(model, lora_config, **model_kwargs)
# FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
# requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
if cfg.torch_dtype:
_fp8_cast_dtype = cfg.torch_dtype
elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
_fp8_cast_dtype = torch.bfloat16
else:
_fp8_cast_dtype = torch.float16
for _name, param in model.named_parameters():
if param.requires_grad and param.dtype == torch.float8_e4m3fn:
param.data = param.data.to(_fp8_cast_dtype)
if rank == 0:
try:
model.print_trainable_parameters()

View File

@@ -215,6 +215,8 @@ class ModelLoader:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
if "allow_all_kernels" not in self.model_kwargs:
self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
self._check_model_requirements()
@@ -503,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
@@ -829,8 +845,9 @@ class ModelLoader:
def _set_z3_leaf_modules(self):
from deepspeed.utils import set_z3_leaf_modules
if self.cfg.model_config_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
if moe_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[moe_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
self.model,

View File

@@ -93,11 +93,13 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._deactivate_hf_async_load()
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -114,6 +116,8 @@ class PatchManager:
self._apply_patch_deepspeed_zero3()
self._apply_voxtral_patches()
self._apply_apertus_patches()
self._apply_trl_vllm_patches()
self._apply_trl_trainer_utils_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -227,6 +231,15 @@ class PatchManager:
patch_sageattn()
def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention:
return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
patch_flash_attn_4(self.model_config)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
@@ -409,17 +422,27 @@ class PatchManager:
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _deactivate_hf_async_load(self):
"""Load weights synchronously so they can be converted and not OOM."""
if self.cfg.load_in_4bit or self.cfg.load_in_8bit:
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
"""Patch transformers weight loading and PEFT for MoE expert quantization."""
has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
if not self.cfg.quantize_moe_experts and not has_target_params:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
@@ -548,15 +571,6 @@ class PatchManager:
LOG.info("Patching with xformers attention...")
hijack_llama_attention()
def _patch_llama_sample_packing(self):
"""Apply sample packing patches for LLaMA models."""
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
hijack_llama_prepare_4d_mask()
def _patch_llama_derived_model(self):
"""Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not (
@@ -568,8 +582,6 @@ class PatchManager:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
self._patch_llama_sample_packing()
elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
@@ -646,6 +658,50 @@ class PatchManager:
patch_apertus_xielu_activation()
def _apply_trl_vllm_patches(self):
"""Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling."""
if (
self.cfg.rl
and getattr(self.cfg, "trl", None)
and getattr(self.cfg.trl, "use_vllm", False)
):
from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
patch_trl_vllm()
def _apply_trl_trainer_utils_patches(self):
"""Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels."""
if not self.cfg.rl:
return
try:
from axolotl.monkeypatch.trainer.utils import (
entropy_from_logits,
selective_log_softmax,
)
except (ImportError, ModuleNotFoundError):
LOG.warning("Triton not available — skipping trl.trainer.utils patches")
return
import trl.trainer.utils
# Guard against repeated calls: only stash the original if trl still
# points at its own implementation (not our wrapper).
if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils
_axolotl_trainer_utils.selective_log_softmax_original = (
trl.trainer.utils.selective_log_softmax
)
trl.trainer.utils.selective_log_softmax = selective_log_softmax
if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
trl.trainer.utils.entropy_from_logits = entropy_from_logits
LOG.info(
"Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits"
)
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:

View File

@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
processor.tokenizer = tokenizer
# Attempt to load image size from processor if available
if (

View File

@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()

View File

@@ -78,30 +78,21 @@ def patch_parallelism_config():
def patch_prepare_cp():
import functools
import contextlib
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
self._cp_context = _noop_cp_context
return args
Accelerator._prepare_cp = patched_prepare_cp

View File

@@ -0,0 +1,104 @@
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""
import torch
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _get_head_dims(model_config):
"""Extract (head_dim, head_dim_v) from a model config.
Handles composite models (e.g. Qwen3.5 VL) via text_config and
MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
"""
cfg = model_config
if hasattr(cfg, "text_config"):
cfg = cfg.text_config
# MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
head_dim_v = getattr(cfg, "v_head_dim", head_dim)
return head_dim, head_dim_v
# Standard models
if hasattr(cfg, "head_dim"):
return cfg.head_dim, cfg.head_dim
if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
head_dim = cfg.hidden_size // cfg.num_attention_heads
return head_dim, head_dim
return None, None
def patch_flash_attn_4(model_config=None):
"""Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
if major not in (9, 10, 11):
return
try:
from flash_attn.cute import ( # noqa: F401
flash_attn_func,
flash_attn_varlen_func,
)
except ImportError:
LOG.info(
"Flash Attention 4 is available for your GPU and offers faster training speeds. "
"To enable: pip install flash-attn-4"
)
return
# Validate head dimensions against FA4's own constraints
head_dim = None
if model_config is not None:
head_dim, head_dim_v = _get_head_dims(model_config)
if head_dim is not None:
try:
from flash_attn.cute.interface import _validate_head_dims
except ImportError:
LOG.warning(
"Could not import _validate_head_dims from flash_attn.cute.interface, "
"unable to verify head dimension compatibility, falling back to FA2"
)
return
# alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
alignment = 8
try:
_validate_head_dims(head_dim, head_dim_v, major, alignment)
except AssertionError as exc:
LOG.warning(
"Model head dimensions not supported by FA4, "
"falling back to FA2: %s",
exc,
)
return
import transformers.modeling_flash_attention_utils as fa_utils
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
fa_utils._pad_input,
fa_utils._unpad_input,
)
_patched_lazy_imports._axolotl_patched = True
fa_utils._lazy_imports = _patched_lazy_imports
LOG.info(
"Flash Attention 4 enabled (head_dim=%s)",
head_dim if model_config else "unknown",
)

View File

@@ -64,15 +64,12 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
main_process_only=True,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info(
"Flex attention compiled successfully.", main_process_only=True
)
LOG.info("Flex attention compiled successfully.")
self._is_flex_compiled = True

View File

@@ -1,24 +0,0 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = _expand_mask

View File

@@ -1,26 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
from transformers import modeling_attn_mask_utils
from transformers.models.llama import modeling_llama
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_llama._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
patched_prepare_4d_causal_attention_mask
)

View File

@@ -51,6 +51,29 @@ QKV_PATCHES = [
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
(
"""
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states, gate = torch.chunk(
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
]
ORIGINAL_O_CODE = """
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
if hasattr(pretrained_model.model, "language_model"):
return pretrained_model.model.language_model.layers
return pretrained_model.model.layers
raise NotImplementedError(

View File

@@ -1,13 +1,4 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
import os
"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""
import bitsandbytes as bnb
import torch
@@ -17,18 +8,20 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
# Module path → param names in definition order, captured before quantization.
# Without this, alphabetical loading order would mismatch merge order.
"expert_param_order": {},
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
"""Dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
@@ -36,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
"""Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
@@ -76,14 +69,11 @@ def replace_parameter_8bit(module, param_name):
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
_moe_load_state["expert_param_order"] = {}
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
@@ -103,14 +93,6 @@ def patch_moe_quantization_on_load(cfg):
_moe_load_state["quant_type"] = quant_type
_moe_load_state["compress_statistics"] = compress_statistics
# Disable async tensor loading. Transformers' convert_and_load_state_dict_in_model
# uses a ThreadPoolExecutor to materialise tensors (move from safetensors → CUDA)
# ahead of time. With MoE models this pre-fetches many large bf16 expert tensors
# onto the GPU simultaneously — long before our set_param_for_module patch can
# quantise and free them one-by-one — causing OOM even at <5 % of weights loaded.
# Sequential loading ensures only ONE bf16 expert tensor is on-GPU at a time.
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
# size for all params, defeating our on-load quantization VRAM savings.
def _noop_warmup(*args, **kwargs):
@@ -123,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
@@ -136,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
)
return
# Record definition order before parametrizations override it
# with alphabetical order.
if mod_path not in _moe_load_state["expert_param_order"]:
_moe_load_state["expert_param_order"][mod_path] = list(
mod._parameters.keys()
)
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
@@ -161,20 +149,28 @@ def get_moe_quantized_count():
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
"""Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
1. Expands short suffixes to full module paths for parametrized modules.
2. Iterates params in definition order (not alphabetical order) so saved
adapters are compatible with standard PEFT, vLLM, etc.
"""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
from contextlib import nullcontext
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.integrations import init_empty_weights
from peft.utils.other import _get_submodules
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
# Expand short suffixes to full paths for parametrized modules.
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
@@ -185,14 +181,74 @@ def patch_peft_target_parameters_matching():
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
target_names_set = expanded
def strip_base_layer_from_name(module_name):
name = ".base_layer"
while name in module_name:
prefix, _, suffix = module_name.rpartition(name)
module_name = prefix + suffix
return module_name
def create_and_replace_param(module_name, key, param_name):
parent, target, target_name = _get_submodules(model, module_name)
unwrapped_module_name = strip_base_layer_from_name(module_name)
unwrapped_module = model.get_submodule(unwrapped_module_name)
if (
isinstance(unwrapped_module, BaseTunerLayer)
and unwrapped_module.__class__.__name__ != "ParamWrapper"
):
raise ValueError(
f"Trying to wrap an `nn.Parameter` of layer "
f"'{unwrapped_module_name}' of type "
f"{type(target).__name__}, which is not a valid target. "
f"Make sure that this layer is not also targeted with "
f"`target_modules`."
)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=key,
parameter_name=param_name.rpartition(".")[-1],
)
# Use definition order (not alphabetical order) for parametrized modules
# so ParamWrapper nesting matches vanilla PEFT on a plain model.
expert_param_order = _moe_load_state.get("expert_param_order", {})
for module_name, module in model.named_modules():
if hasattr(module, "parametrizations"):
stored_order = expert_param_order.get(module_name)
if stored_order is not None:
params_iter = [
p for p in stored_order if p in module.parametrizations
]
else:
# Fallback for paths that bypass model loading (e.g. unit tests).
params_iter = list(module.parametrizations.keys())
for param_name in params_iter:
key = f"{module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
else:
unwrapped_module_name = strip_base_layer_from_name(module_name)
for param_name, _ in module.named_parameters(recurse=False):
key = f"{unwrapped_module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")

View File

@@ -57,7 +57,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"olmo3",
"ministral",
"ministral3",
"mistral4",
"afmoe",
"nemotron",
]

View File

@@ -154,7 +154,6 @@ def register_ring_attn_from_device_mesh(
LOG.info(
f"Enabling ring attention sequence parallelism using DeviceMesh "
f"dimension '{context_parallel_dim}'",
main_process_only=True,
)
# Extract the sequence parallel submesh

View File

@@ -85,7 +85,6 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
mlp_cls._tiled_mlp_dist_impl = None
LOG.info(
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
main_process_only=True,
)
except (ImportError, AttributeError) as e:
raise RuntimeError(

Some files were not shown because too many files have changed in this diff Show More