Compare commits

...

35 Commits

Author SHA1 Message Date
NanoCode012
255c5b90ca fix: make prepare_context_parallel_inputs no-op 2026-03-20 16:30:58 +07:00
Lorenzo Baraldi
038ffe3f26 fix: solved double sequence partition from SequenceParallelContextManager and Accelerate's native CP (#3498) 2026-03-20 16:27:24 +07:00
VED
c13cb7c853 feat: add nemotron config (#3506)
* nemotron config exp

* Update examples/nemotron/nemotron-mini-4b-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-03-20 16:23:42 +07:00
VED
b3823cc6b0 fix: gemma3 configs (#3500) [skip ci]
* gemma fft, text fix

* good lint
2026-03-20 16:14:06 +07:00
VED
113d275bd9 qwen docs + new config (#3499) [skip ci]
* qwen docs + new config

* docs lint

* simplify comments

* read me

* lint comments

* Update docs/multimodal.qmd

* Update docs/multimodal.qmd

* Update examples/qwen3.5/9b-fft-vision.yaml

* chore: fix link and incorrect points

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-20 16:13:34 +07:00
VED
7920fe74ec fix num_labels=1 test fail (#3493) [skip ci]
* trl num_labels=1

* causal num_labels=1, reward model

* lint
2026-03-20 16:12:23 +07:00
Wing Lian
1fc86d5295 Scattermoe LoRA optimizations (#3513)
* optimize moe + lora

* more scattermoe optims

* selective dequant

* add correctness unit tests and benchmarks for scattermoe + lora

* handle base+lora split kernel for older moe models

* chore: lint

* fix casting for H200 and B200

* register pressure estimation and pruning for h200/b200

* use soft limit for pruning

* qkv patch for qwen3.5moe

* support text_model for qwen3.5 moe

* nesting of qwen3

* use updated cce with zero3 support

* Fix decomposed backward for QKV and O projections

Eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
2026-03-19 23:07:42 -04:00
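For illustration, a minimal sketch of the decomposition this commit message describes (assumed shapes and scaling; not the Triton kernel itself):

```python
import torch

# Hedged sketch: for y = x @ (W + s * B @ A).T, the input gradient is
# dX = dY @ (W + s * B @ A), which can be computed without ever
# materializing the full [out, in] product B @ A.
T, d_in, d_out, r, s = 64, 512, 512, 16, 2.0
dY = torch.randn(T, d_out)
W = torch.randn(d_out, d_in)
A = torch.randn(r, d_in)   # LoRA down-projection
B = torch.randn(d_out, r)  # LoRA up-projection

dX_naive = dY @ (W + s * (B @ A))   # materializes a [d_out, d_in] matrix
dX = dY @ W + s * ((dY @ B) @ A)    # two small [T, r]-sized matmuls instead
assert torch.allclose(dX, dX_naive, rtol=1e-3, atol=1e-2)
```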
Wing Lian
bb483ad4c4 make the CI fail GitHub Actions on test failures (#3517)
* make the CI fail GitHub Actions on test failures

* use model bundle

* install zstd for compressed model artifact
2026-03-19 08:29:24 -04:00
Wing Lian
163bd4dd5a use custom triton kernels for entropy from logits and selective softmax (#3510)
* use custom triton kernels for entropy from logits and selective softmax

* PR comments fixes

* fix out of bounds, include tests, include benchmarks

* chore: lint
2026-03-19 02:02:43 -04:00
Wing Lian
f291ac029c fix for flaky tests in lora ops kernels w autotune (#3511) [skip ci]
* fix for flaky tests in lora ops kernels w autotune

* attempt 2 to fix
2026-03-19 01:18:47 -04:00
Wing Lian
5ef3f28340 Support for Async GRPO (#3486)
* async grpo support

* implement data producer

* use fast async

* handle call to create data producer

* fix liger kernel setup

* fix replay buffer

* chore: lint

* make gpus go brrr

* chore: lint

* inplace div_, unwrap model for logits in bf16

* fuse selective softmax and empty cuda cache on each scoring step

* remove waiting for synch time and fix race

* make fp8 work and allow lora kernels w rl

* grpo with lora vllm sync and fixes for sharded distributed

* update docs

* more patches so it works against trl main

* address PR feedback from coderabbit
2026-03-17 11:42:47 -04:00
Aarush
999b3fec2e fix: replace shell=True subprocess with argument list in modal CLI (#3487)
* fix: replace shell=True subprocess with argument list in modal CLI

Using shell=True with a formatted string containing docker_image
(a user-controlled value) is a command injection risk (Bandit B602).
Replace with an argument list, which passes args directly to the
process without shell interpretation, removing the nosec annotation.

* fix: add nosec annotation to suppress bandit B603/B607 warnings

Removing shell=True (B602) surfaces B603 (subprocess without shell)
and B607 (partial executable path for 'docker'). Use bare # nosec
to suppress both, consistent with other nosec usages in the codebase.
2026-03-17 08:53:13 -04:00
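A hedged sketch of the pattern this commit adopts (the `docker pull` command below is a stand-in, not the actual modal CLI call):

```python
import subprocess

docker_image = "axolotlai/axolotl:main-latest"  # hypothetical user-supplied value

# Before (risky): the formatted string is interpreted by a shell, so a
# crafted docker_image could inject extra commands (Bandit B602):
#   subprocess.run(f"docker pull {docker_image}", shell=True, check=True)

# After: the argument list goes straight to the process, with no shell:
subprocess.run(["docker", "pull", docker_image], check=True)  # nosec B603 B607
```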
Wing Lian
8f3fb517b3 consolidate behaviour of routing in scattermoe kernels (#3475)
* consolidate behaviour of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
2026-03-16 23:47:40 -04:00
Wing Lian
830e9f7eaf automatically enable tf32 if supported (#3473) [skip ci]
* automatically enable tf32 if supported

* update fixtures

* handle only when True

* Address CR comments

* address readability from pr comment

* simplify
2026-03-16 23:47:00 -04:00
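A minimal sketch of the auto-enable logic (assuming the standard PyTorch flags; Ampere, compute capability 8.0+, is the usual hardware gate):

```python
import torch

# TF32 tensor cores are available on Ampere (SM 8.0) and newer GPUs
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
```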
NanoCode012
d230cbbde3 chore(doc): update readme (#3503) [skip ci] 2026-03-17 09:43:24 +07:00
NanoCode012
a098df527b feat: add Mistral Small 4 (#3502)
* feat: add mistral small 4

* fix: update mistral common

* fix: deepcopy when passing in tokenizer

* feat: add doc on reasoning and thinking section

* fix: don't use custom tokenizer and quantize experts

* chore: update docs and configs

* chore: update doc to follow official name

* feat: update cce to include mistral4

* chore: move

* fix: naming

* fix: test mock breaking get_text_config check

* fix: enable CCE and add expert block targeting to configs

* chore: docs

* fix: use act checkpointing

* chore: doc

* chore: docs

* chore: docs
2026-03-17 09:39:05 +07:00
NanoCode012
7da5f94379 feat: add FA4 (#3481)
* feat: add FA4

* chore: update docs

* fix: recommend FA4 for those with compatible devices

* fix: adjust import check and add head_dim check

* chore: add limitation to doc

* fix: log warning and quit if cannot import validator

* chore: simplify

* fix: add caveat with FA2 shadow dir
2026-03-16 00:13:18 -04:00
NanoCode012
4a5876df7a fix: explicitly set workflow permissions and move secrets to necessary steps (#3484) [skip ci]
* fix: explicitly set workflow permissions and move secrets to necessary
steps only

* fix: comment

* fix: further restrict permissions

* chore: add read for pypi
2026-03-16 00:13:05 -04:00
Aarush
defee62d99 fix: fix CONTRIBUTING.md placeholders, bare except clauses, and add convert.py tests (#3485) [skip ci]
* docs: fix codestyle placeholders in CONTRIBUTING.md

Replace unresolved {codestyle} and {URLofCodestyle} template
variables with Ruff, the project's actual linter/formatter,
as configured in .pre-commit-config.yaml.

* fix: replace bare except clauses with specific exception types

- quantization.py: use except ImportError for optional torchao imports
  (consistent with line 48 which already uses ImportError correctly)
- cli/config.py: use except (RuntimeError, AssertionError) for CUDA
  device property query

Prevents masking unrelated errors like KeyboardInterrupt or SystemExit.

* test: add unit tests for convert.py JSON/JSONL utilities

Cover FileReader, FileWriter, StdoutWriter, JsonParser,
JsonlSerializer, and JsonToJsonlConverter with 8 test cases
including roundtrip and edge case (empty list) scenarios.

Previously this module had zero test coverage.

* fix: address CodeRabbit review feedback

- quantization.py: catch (ImportError, RuntimeError) for optional
  torchao imports; CUDA wheel/GPU mismatches raise RuntimeError,
  not ImportError
- convert.py: remove unused output_file_path parameter from
  JsonToJsonlConverter.convert() — FileWriter already holds the
  output path from construction
- tests/test_convert.py: update call site to match new signature
2026-03-16 00:12:40 -04:00
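A sketch of the resulting import guard (the imported symbol is illustrative):

```python
try:
    import torchao  # optional dependency
except (ImportError, RuntimeError):
    # RuntimeError covers CUDA wheel/GPU mismatches torchao can raise;
    # a bare `except` would also swallow KeyboardInterrupt and SystemExit
    torchao = None
```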
VED
f56efdb4ab fix: high eval loss w/ sample packing (#3478) [skip ci]
* check if eval_sp

* readable condition
2026-03-15 22:11:23 -04:00
NanoCode012
d8a646c80d chore: logging cleanup (#3482) [skip ci] 2026-03-15 22:10:57 -04:00
VED
a806704e94 moe quant patch for merge mismatch (#3483)
* moe quant patch for merge mismatch

* lint

* revert test + fix moe patch

* comment fixes

* e2e tests

* mismatch fix tested

* mismatch fix with vllm compatibility + test

* comment lint

* fix: missing os import, duplicate no-op

* chore: simplify comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-03-15 22:10:30 -04:00
Wing Lian
d8a05744d7 Reverts commits 79908b3c6, 083c5a042, e1ff75624, ff77fa248. (#3496)
The non-root user approach had multiple issues with RunPod
compatibility, sudo PATH handling, and tmux in exec sessions.
Restoring root as the default user for now.
2026-03-13 11:54:09 -04:00
Wing Lian
ff77fa2488 preserve env for root -> ubuntu user (#3495) 2026-03-13 10:19:34 -04:00
Wing Lian
e1ff756245 become the ubuntu user when root logs in (#3494) 2026-03-13 09:06:54 -04:00
Wing Lian
083c5a0421 check ubuntu user and set uv python dir (#3492) 2026-03-12 23:20:54 -04:00
Wing Lian
79908b3c6e use ubuntu user instead of root for uv docker images (#3491) 2026-03-12 20:41:13 -04:00
Wing Lian
819b157c7b swap around what we're building for docker (#3490)
* remove cloud configuration we don't build base images for

* but we do want it for uv
2026-03-11 21:45:13 -04:00
Wing Lian
fccc712dae builds for py312-cu128-torch2.9.1 (#3489) 2026-03-11 20:09:03 -04:00
NanoCode012
23ad40bdd5 fix: disable async load when loading quantized bnb 2026-03-11 13:18:27 +07:00
NanoCode012
cf4d550c88 fix: reduce permissions for preview docs CI (#3480) [skip ci] 2026-03-09 08:04:31 -04:00
Wing Lian
43b1c80aa6 load weights synchronously so they can be converted and not OOM (#3477) 2026-03-07 07:09:24 -05:00
Wing Lian
a36aaa70ce add gpu tests for scattermoe (#3474) [skip ci] 2026-03-07 00:00:48 -05:00
Wing Lian
80f7088ad1 update setuptools so trl can be installed from main for nightlies (#3471)
* update setuptools so trl can be installed from main for nightlies

* run the nightly in the PR CI on change

* use range request, don't use cu129 in CI since it's not supported with AO

* run multigpu ci if CCE install script changes
2026-03-06 14:59:25 -05:00
Wing Lian
46b9f40f2a bump dev version to 0.16.0.dev0 (#3472) [skip ci] 2026-03-06 14:59:00 -05:00
116 changed files with 14784 additions and 526 deletions

View File

@@ -68,7 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
Use the pre-commit linter to ensure that your code is formatted consistently.
```bash
@@ -83,6 +83,6 @@ Write clear and concise commit messages that briefly describe the changes made i
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})
- [Ruff](https://docs.astral.sh/ruff/)
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

View File

@@ -15,6 +15,9 @@ on:
- '.github/workflows/base.yml'
workflow_dispatch:
permissions:
contents: read
jobs:
build-base:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -124,7 +127,7 @@ jobs:
images: |
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -132,7 +135,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
@@ -173,6 +176,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -239,7 +250,7 @@ jobs:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -247,7 +258,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
uses: docker/build-push-action@v5
with:
context: .
file: ./docker/${{ matrix.dockerfile }}

View File

@@ -13,6 +13,9 @@ on:
- ".pre-commit-config.yaml"
workflow_dispatch:
permissions:
contents: read
jobs:
pre-commit:
name: pre-commit

View File

@@ -8,6 +8,9 @@ on:
- "v*"
workflow_dispatch:
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -110,6 +113,12 @@ jobs:
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
@@ -174,6 +183,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
@@ -259,6 +269,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
@@ -266,6 +277,12 @@ jobs:
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
@@ -326,6 +343,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128

View File

@@ -8,6 +8,7 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
@@ -19,6 +20,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
@@ -35,13 +39,13 @@ jobs:
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -77,8 +81,9 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run -m cicd.multigpu

View File

@@ -5,6 +5,9 @@ on:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}

View File

@@ -5,6 +5,8 @@ on:
- cron: '0 0 1 * *' # Run monthly
workflow_dispatch: # Manual kickoff
permissions: {}
jobs:
auto-update:
runs-on: ubuntu-latest

View File

@@ -14,14 +14,8 @@ on:
- .github/workflows/preview-docs.yml
permissions:
checks: write
contents: write
deployments: write
issues: write
discussions: write
pages: write
contents: read
pull-requests: write
statuses: write
jobs:
preview:

View File

@@ -3,9 +3,11 @@ name: publish pypi
on:
push:
tags:
- 'v*'
- "v*"
workflow_dispatch:
permissions: {}
jobs:
setup_release:
name: Create Release
@@ -28,7 +30,8 @@ jobs:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
contents: read
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v4
@@ -46,7 +49,7 @@ jobs:
- name: Extract tag name
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
- name: Update version in VERSION file
run: |

View File

@@ -3,6 +3,13 @@ on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '.github/workflows/tests-nightly.yml'
permissions:
contents: read
jobs:
pre-commit:
@@ -27,7 +34,7 @@ jobs:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
@@ -35,7 +42,6 @@ jobs:
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
@@ -60,7 +66,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
- name: Install PyTorch
run: |
@@ -153,8 +159,9 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
@@ -195,7 +202,8 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.multigpu

View File

@@ -28,6 +28,9 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
TRANSFORMERS_IS_CI: "yes"
@@ -55,7 +58,7 @@ jobs:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
@@ -303,9 +306,10 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -371,9 +375,10 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -413,7 +418,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

View File

@@ -30,7 +30,7 @@
## 🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
@@ -75,7 +75,7 @@ Features:
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

View File

@@ -1 +1 @@
0.15.0
0.16.0.dev0

208
benchmarks/bench_entropy.py Normal file
View File

@@ -0,0 +1,208 @@
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
"""
import gc
import statistics
import torch
import torch.nn.functional as F
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
"""Original chunked implementation (reference)."""
original_shape = logits.shape[:-1]
num_classes = logits.shape[-1]
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, logits, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = fn(logits, chunk_size=128)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
del out
return times
def profile_memory(fn, logits, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(logits, chunk_size=128)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_contiguous():
print("=" * 60)
print(
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(1, 16384),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
t_orig = profile_time(entropy_from_logits_original, logits)
t_triton = profile_time(entropy_from_logits, logits)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(entropy_from_logits_original, logits)
m_triton = profile_memory(entropy_from_logits, logits)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits
_clean_gpu()
def benchmark_noncontiguous():
print("\n" + "=" * 60)
print(
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(4, 2048, "transpose"),
(4, 8192, "transpose"),
(8, 2048, "transpose"),
(4, 4096, "slice_batch"),
]
for B, L, method in configs:
torch.manual_seed(42)
if method == "transpose":
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw.transpose(0, 1)
raw_gb = L * B * V * 2 / 1e9
elif method == "slice_batch":
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw[::2]
raw_gb = B * 2 * L * V * 2 / 1e9
else:
continue
if raw_gb > 28:
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
del raw, logits_nc
torch.cuda.empty_cache()
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
print(f"{'' * 60}")
def original_with_copy(logits, chunk_size=128):
return entropy_from_logits_original(
logits.contiguous(), chunk_size=chunk_size
)
t_orig = profile_time(original_with_copy, logits_nc)
t_triton = profile_time(entropy_from_logits, logits_nc)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" orig+copy: {fmt(t_orig, 'ms')}")
print(f" triton-strided:{fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(original_with_copy, logits_nc)
m_triton = profile_memory(entropy_from_logits, logits_nc)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" orig+copy: {fmt(m_orig, 'MB')}")
print(f" triton-strided:{fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del raw, logits_nc
_clean_gpu()
if __name__ == "__main__":
benchmark_contiguous()
benchmark_noncontiguous()

View File

@@ -0,0 +1,284 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,191 @@
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
"""
import gc
import statistics
import torch
from axolotl.monkeypatch.trainer.utils import (
selective_log_softmax,
selective_log_softmax_original,
)
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, args, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
fn(*args)
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn(*args)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return times
def profile_memory(fn, args, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(*args)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(*args)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_forward():
print("=" * 60)
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(selective_log_softmax_original, (logits, index))
t_triton = profile_time(selective_log_softmax, (logits, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
m_triton = profile_memory(selective_log_softmax, (logits, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits, index
_clean_gpu()
def benchmark_backward():
print("\n" + "=" * 60)
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
def fwd_bwd_original(logits, index):
logits.grad = None
out = selective_log_softmax_original(logits, index)
out.sum().backward()
def fwd_bwd_triton(logits, index):
logits.grad = None
out = selective_log_softmax(logits, index)
out.sum().backward()
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 20:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits_orig = torch.randn(
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
logits_tri = logits_orig.detach().clone().requires_grad_(True)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" FWD+BWD TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" FWD+BWD MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits_orig, logits_tri, index
_clean_gpu()
if __name__ == "__main__":
benchmark_forward()
benchmark_backward()

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -3,11 +3,12 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \

View File

@@ -68,10 +68,6 @@ def run_cmd(cmd: str, run_folder: str):
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")

View File

@@ -13,9 +13,10 @@ sdp_attention: true
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
## Flash Attention
Uses efficient kernels to compute attention.
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
```yaml
flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
```bash
pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
Alternatively, try reinstalling or downgrading your PyTorch version.
:::
#### Flash Attention 3
### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
@@ -50,6 +50,44 @@ cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
```
Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```
::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.
:::
::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
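A hedged sketch of the head-dimension gate described above (an illustrative helper, not Axolotl's actual validator):

```python
def fa4_head_dim_supported(qk_head_dim: int, v_head_dim: int, is_blackwell: bool) -> bool:
    """Mirror the documented FA4 constraints: d <= 128, plus (192, 128) on Blackwell."""
    if qk_head_dim <= 128 and v_head_dim <= 128:
        return True
    return (qk_head_dim, v_head_dim) == (192, 128) and is_blackwell
```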
### AMD
Requirements: ROCm 6.0 and above.

View File

@@ -13,12 +13,14 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Mistral-Small-4](#sec-mistral-small-4)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
@@ -108,6 +110,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
### Mistral-Small-4 {#sec-mistral-small-4}
```yaml
base_model: mistralai/Mistral-Small-4-119B-2603
```
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
@@ -184,6 +192,14 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.

View File

@@ -721,6 +721,213 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
#### Async GRPO
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
```yaml
trl:
use_data_producer: true # Enable data producer protocol
use_vllm: true
async_prefetch: true # Generate rollouts in background thread
prefetch_depth: 1 # Number of rollouts to prefetch
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```
::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::
##### vLLM LoRA Sync
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
vllm_lora_sync: true # Enable native LoRA sync
```
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```
Then start training on a separate GPU:
```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::
##### Streaming Partial Batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
```yaml
trl:
streaming_partial_batch: true
```
##### Importance Sampling Correction
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
```yaml
trl:
vllm_importance_sampling_correction: true # Enable IS correction
importance_sampling_level: token # 'token' or 'sequence'
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
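As a rough sketch of token-level correction with off-policy masking (assumed semantics, not the exact TRL/Axolotl implementation):

```python
import torch

def is_corrected_weights(logps_current, logps_rollout, threshold=0.5):
    # per-token importance ratio between the current policy and the
    # slightly stale policy that generated the rollout
    ratios = torch.exp(logps_current - logps_rollout)
    # mask whole sequences whose mean ratio marks them as too far off-policy
    seq_ratio = ratios.mean(dim=-1, keepdim=True)
    keep = (seq_ratio >= threshold).float()
    return ratios * keep
```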
##### Replay Buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
```yaml
trl:
replay_buffer_size: 100 # Max cached groups (0 = disabled)
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```
::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
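A minimal sketch of the caching logic (assumed semantics):

```python
from collections import deque

replay = deque(maxlen=100)  # replay_buffer_size

def with_replay(group, rewards):
    if rewards.std() > 0:               # learning signal: cache for reuse
        replay.append((group, rewards))
        return group, rewards
    if replay:                          # zero variance: swap in a cached group
        return replay.popleft()
    return group, rewards               # nothing cached yet; keep as-is
```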
##### Deferred Re-rolling
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
```yaml
trl:
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
reroll_max_groups: 1 # Max groups to replace per batch
```
##### Zero-Advantage Batch Skipping
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
```yaml
trl:
skip_zero_advantage_batches: true # default
```
##### Parallel Reward Workers
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
```yaml
trl:
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```
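A sketch of why subprocesses sidestep the restriction: each worker process runs the reward function on its own main thread, so `signal.alarm()`-based timeouts keep working (the reward function below is a toy stand-in):

```python
from concurrent.futures import ProcessPoolExecutor

def toy_reward(completion: str) -> float:
    # stand-in for a reward fn that may use signal.alarm() internally
    return float("42" in completion)

def score_parallel(completions, num_workers=4):
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(toy_reward, completions))
```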
##### Full Async GRPO Example
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.35
dtype: auto
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
rl: grpo
trl:
use_data_producer: true
use_vllm: true
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
vllm_lora_sync: true
streaming_partial_batch: true
vllm_importance_sampling_correction: true
off_policy_mask_threshold: 0.5
importance_sampling_level: token
num_generations: 8
max_completion_length: 512
reward_funcs:
- rewards.accuracy_reward
reroll_start_fraction: 0.5
replay_buffer_size: 100
reward_num_workers: 4
skip_zero_advantage_batches: true
datasets:
- path: AI-MO/NuminaMath-TIR
type: rewards.prompt_transform
split: train
gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
##### Multi-GPU Async GRPO
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
**FSDP:**
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
use_reentrant: false
```
**DeepSpeed ZeRO-3:**
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
use_reentrant: true # Required for ZeRO-3
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```
::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
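A condensed sketch of the produce-then-broadcast pattern (assumed shape; not Axolotl's internals):

```python
import torch.distributed as dist

def gather_rollouts(produce_fn):
    # only rank 0 generates; every rank receives the result on the main thread
    payload = [produce_fn() if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```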
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
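As an illustration of the per-reward normalization idea (an assumed formulation, not necessarily GDPO's exact estimator):

```python
import torch

def gdpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards: [num_generations, num_reward_funcs], one column per reward fn
    mean = rewards.mean(dim=0, keepdim=True)
    std = rewards.std(dim=0, keepdim=True).clamp_min(1e-6)
    normalized = (rewards - mean) / std  # normalize each reward independently
    return normalized.sum(dim=-1)        # combine only after normalization
```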

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
},
{

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,8 +1,5 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -27,6 +24,11 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,9 +1,5 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
@@ -24,6 +20,11 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_model_dir:

View File

@@ -0,0 +1,85 @@
# Finetune Mistral Small 4 with Axolotl
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
Note: Training this model requires the weights in BF16, which we will link here later.
Users interested in training can convert/descale the existing FP8 weights themselves; a rough sketch follows.
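The sketch below is illustrative only; the scale-tensor naming and layout are assumptions and may not match the released checkpoint (which may, for example, use blockwise scales):
```python
import torch
from safetensors.torch import load_file, save_file

state = load_file("model.safetensors")
converted = {}
for name, tensor in state.items():
    if tensor.dtype == torch.float8_e4m3fn:
        # Assumed per-tensor scale stored alongside the weight (hypothetical key).
        scale = state[name.replace(".weight", ".weight_scale")]
        converted[name] = (tensor.to(torch.float32) * scale).to(torch.bfloat16)
    elif not name.endswith(".weight_scale"):
        converted[name] = tensor
save_file(converted, "model-bf16.safetensors")
```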
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install transformers from main:
```bash
pip install git+https://github.com/huggingface/transformers.git
```
4. Run one of the example configs:
```bash
# text-only
axolotl train examples/mistral4/qlora-text.yml # ~69 GiB without expert training, ~93 GiB with experts
axolotl train examples/mistral4/fft-text.yml
# text + vision
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral4/qlora-vision.yml # ~68 GiB without expert training
axolotl train examples/mistral4/fft-vision.yml
```
Note: The FFT configs are provided as a reference. Please adjust hyperparameters as needed.
## Reasoning Effort
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps
Pass it via `chat_template_kwargs` under your dataset config:
```yaml
datasets:
- path: your/dataset
type: chat_template
chat_template_kwargs:
reasoning_effort: high
```
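Outside of a training config, the same variable can be passed directly to the tokenizer, since extra keyword arguments to `apply_chat_template` are forwarded to the template (a sketch):
```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-4-119B-2603")
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Why is the sky blue?"}],
    tokenize=False,
    add_generation_prompt=True,
    reasoning_effort="high",  # consumed by the chat template
)
```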
## Thinking Support
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
```yaml
datasets:
- path: your/thinking-dataset
type: chat_template
message_property_mappings:
role: role
content: content
thinking: thinking
chat_template_kwargs:
reasoning_effort: high
```
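With that mapping, a single training record might look like the following (schematic JSONL row, shown pretty-printed; the `messages` field name is axolotl's default and can be overridden with `field_messages`):
```json
{"messages": [
  {"role": "user", "content": "What is 17 * 24?"},
  {"role": "assistant",
   "thinking": "17 * 24 = 17 * 20 + 17 * 4 = 340 + 68 = 408",
   "content": "408"}
]}
```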
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
## Tips
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Related Resources
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,58 @@
base_model: mistralai/Mistral-Small-4-119B-2603
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# only train language model layers, freeze vision tower
unfrozen_parameters:
- model.language_model.*
- lm_head
- embed_tokens
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,57 @@
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# vision requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,58 @@
base_model: mistralai/Mistral-Small-4-119B-2603
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,63 @@
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
# vision chat template requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,57 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
special_tokens:

View File

@@ -0,0 +1,59 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,49 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,10 +1,6 @@
base_model: Qwen/Qwen3.5-7B
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false

View File

@@ -1,15 +1,20 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
Vision and text tokens are processed through the same transformer stack. The configs below train on text-only data unless noted otherwise. See `9b-lora-vision.yaml` for a multimodal example.
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense VLM, text-only FFT (vision frozen) | ~53 GiB |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
## Getting started
@@ -29,23 +34,31 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
# Dense 27B text-only FFT with vision encoder frozen (~53 GiB, single 80 GiB GPU)
axolotl train examples/qwen3.5/27b-fft.yaml
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
# 9B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/9b-lora-vision.yaml
# 9B vision+text FFT, single 80 GiB GPU (~61 GiB peak)
axolotl train examples/qwen3.5/9b-fft-vision.yaml
```
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- For **text-only FFT** on 27B, use `27b-fft.yaml` which sets `unfrozen_parameters` to freeze the vision encoder (`model.visual.*`) — this avoids wasting optimizer state on parameters that receive no gradient from text-only data.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
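For example (a sketch; the exact `linear_attn` submodule names are assumptions and depend on the released architecture):
```yaml
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - up_proj
  - down_proj
  # optionally also adapt the Gated DeltaNet layers:
  # - linear_attn.q_proj
  # - linear_attn.k_proj
  # - linear_attn.v_proj
```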
## Optimization Guides

View File

@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.8
mistral-common==1.10.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
)

View File

@@ -90,9 +90,8 @@ class ModalCloud(Cloud):
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
manifest = subprocess.check_output( # nosec
["docker", "manifest", "inspect", docker_image],
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:

View File

@@ -11,7 +11,7 @@ from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
@@ -300,7 +300,7 @@ def load_cfg(
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except:
except (RuntimeError, AssertionError):
gpu_version = None
prepare_plugins(cfg)
@@ -310,6 +310,7 @@ def load_cfg(
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"tf32": is_torch_tf32_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},

View File

@@ -196,12 +196,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)

View File

@@ -38,7 +38,18 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
if (
serve_module is None
and getattr(cfg, "trl", None)
and getattr(cfg.trl, "vllm_lora_sync", False)
):
serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
serve_module = "trl.scripts.vllm_serve"
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -68,7 +79,7 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
vllm_script_args = AxolotlScriptArguments(
base_kwargs = dict(
model=model,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
@@ -78,7 +89,21 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
# Use LoRAScriptArguments when serving with native LoRA support
if serve_module == "axolotl.scripts.vllm_serve_lora":
from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
lora_kwargs = {}
if hasattr(cfg, "lora_r") and cfg.lora_r:
lora_kwargs["max_lora_rank"] = cfg.lora_r
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
else:
vllm_script_args = AxolotlScriptArguments(
**base_kwargs,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)

View File

@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"mistral4": "Mistral4MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",

View File

@@ -67,7 +67,7 @@ class JsonToJsonlConverter:
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(self, input_file_path, output_file_path):
def convert(self, input_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
training_args_kwargs["tf32"] = True if self.cfg.tf32 is True else False
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:

View File

@@ -421,6 +421,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,

View File

@@ -54,8 +54,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
@@ -151,7 +159,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
training_args_cls = GRPOStrategy.get_training_args_class(
async_grpo=async_grpo
)
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
@@ -217,13 +234,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
# transformers' validate_quantization_for_training blocks FP8 because
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
# (base weights stay frozen in FP8).
_orig_validate_quant = None
if (
self.cfg.adapter
and hasattr(self.model, "is_quantized")
and self.model.is_quantized
):
import transformers.trainer as _trainer_module
_orig_validate_quant = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
finally:
if _orig_validate_quant is not None:
import transformers.trainer as _trainer_module
_trainer_module.validate_quantization_for_training = (
_orig_validate_quant
)
if self.cfg.fsdp_config or self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:

View File

@@ -9,8 +9,9 @@ from huggingface_hub import snapshot_download
from requests import HTTPError
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOTrainer,
)
@@ -27,14 +28,31 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
cls,
sequence_parallel: bool,
async_grpo: bool = False,
) -> (
type[AxolotlGRPOTrainer]
| type[AxolotlGRPOSequenceParallelTrainer]
| type[AxolotlAsyncGRPOTrainer]
):
if sequence_parallel and async_grpo:
raise ValueError(
"sequence_parallel and async_grpo cannot both be enabled. "
"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
)
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
return AxolotlAsyncGRPOTrainer
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
def get_training_args_class(
cls, async_grpo: bool = False
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
if async_grpo:
return AxolotlAsyncGRPOConfig
return AxolotlGRPOConfig
@classmethod
@@ -124,13 +142,63 @@ class GRPOStrategy:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
# Async GRPO fields
if getattr(trl, "use_data_producer", None) is not None:
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
if getattr(trl, "async_prefetch", None) is not None:
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "prefetch_depth", None) is not None:
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
if getattr(trl, "vllm_sync_interval", None) is not None:
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "streaming_partial_batch", None) is not None:
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
if getattr(trl, "streaming_min_groups", None) is not None:
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_correction"] = (
trl.vllm_importance_sampling_correction
)
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_mode"] = (
trl.vllm_importance_sampling_mode
)
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_cap"] = (
trl.vllm_importance_sampling_cap
)
if getattr(trl, "off_policy_mask_threshold", None) is not None:
grpo_args_kwargs["off_policy_mask_threshold"] = (
trl.off_policy_mask_threshold
)
if getattr(trl, "use_bias_correction_kl", None) is not None:
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
# Fast Async GRPO fields
if getattr(trl, "reward_num_workers", None) is not None:
grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
if getattr(trl, "replay_buffer_size", None) is not None:
grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
if getattr(trl, "replay_recompute_logps", None) is not None:
grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
if getattr(trl, "reroll_start_fraction", None) is not None:
grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
if getattr(trl, "reroll_max_groups", None) is not None:
grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
if getattr(trl, "skip_zero_advantage_batches", None) is not None:
grpo_args_kwargs["skip_zero_advantage_batches"] = (
trl.skip_zero_advantage_batches
)
if getattr(trl, "vllm_lora_sync", None) is not None:
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
return grpo_args_kwargs
@classmethod

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@@ -14,3 +15,10 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
context_parallel_size: int | None = None
@dataclass
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
context_parallel_size: int | None = None

File diff suppressed because it is too large

View File

@@ -0,0 +1,768 @@
# Copyright 2020-2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Experimental GRPO extensions: parallel reward workers, replay buffer,
deferred re-roll, and zero-advantage skipping.
These features are built as subclasses of GRPOTrainer and GRPODataProducer,
using the hook system (_compute_rewards_for_batch, _post_advantage_hook,
_pre_produce_hook) defined in the base classes.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from dataclasses import dataclass, field
import torch
from torch import nn
from trl import GRPOTrainer
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOConfig,
AsyncGRPOTrainer,
GRPODataProducer,
)
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Extended config
# ---------------------------------------------------------------------------
@dataclass
class FastAsyncGRPOConfig(AsyncGRPOConfig):
"""GRPOConfig with additional experimental parameters."""
reward_num_workers: int = field(
default=1,
metadata={
"help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
},
)
replay_buffer_size: int = field(
default=0,
metadata={
"help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
},
)
replay_recompute_logps: bool = field(
default=True,
metadata={
"help": "When True (default), recompute old_per_token_logps for replayed groups using the current "
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
"Only relevant when replay_buffer_size > 0."
},
)
reroll_start_fraction: float = field(
default=0.5,
metadata={
"help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
},
)
reroll_max_groups: int = field(
default=1,
metadata={
"help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
},
)
skip_zero_advantage_batches: bool = field(
default=True,
metadata={
"help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
"logged with skipped_zero_adv_batches=1 for monitoring."
},
)
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model "
"and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. "
"Requires vllm_serve_lora serve module (auto-selected when this is True). "
"Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False."
},
)
# ---------------------------------------------------------------------------
# Extended data producer with re-roll injection
# ---------------------------------------------------------------------------
class RerollDataProducer(GRPODataProducer):
"""GRPODataProducer that injects re-roll candidates into prompt batches.
Reads from the trainer's ``_reroll_buffer`` (populated by
``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the
last N prompt groups with previously-failed prompts.
"""
def _pre_produce_hook(self, inputs: list, global_step: int) -> list:
trainer = self._trainer
reroll_buf = getattr(trainer, "_reroll_buffer", None)
reroll_lock = getattr(trainer, "_reroll_lock", None)
if reroll_buf is None or reroll_lock is None:
return inputs
max_steps = getattr(trainer.args, "max_steps", -1)
start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0)
max_groups = getattr(trainer.args, "reroll_max_groups", 1)
reroll_start_step = (
max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf")
)
if global_step < reroll_start_step:
return inputs
with reroll_lock:
n_to_take = min(max_groups, len(reroll_buf))
reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]
if reroll_prompts:
num_gen = self._num_generations
n_groups = len(inputs) // num_gen
for i, reroll_prompt in enumerate(reroll_prompts):
group_idx = n_groups - 1 - i
if group_idx < 0:
break
start = group_idx * num_gen
for j in range(num_gen):
inputs[start + j] = reroll_prompt
logger.info(
f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups "
f"with deferred re-roll candidates ({len(reroll_buf)} remaining)"
)
return inputs
# ---------------------------------------------------------------------------
# Persistent reward subprocess pool
# ---------------------------------------------------------------------------
def _persistent_reward_worker(conn):
"""Long-lived reward worker. Receives work items, returns results."""
while True:
try:
msg = conn.recv()
except EOFError:
break
if msg is None: # Shutdown signal
break
(
reward_funcs,
prompts,
completions,
completion_ids_list,
inputs,
reward_func_names,
) = msg
try:
keys = [
key
for key in inputs[0]
if key not in ["prompt", "completion", "completion_ids"]
]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
results = []
for reward_func, _reward_func_name in zip(
reward_funcs, reward_func_names, strict=True
):
output = reward_func(
prompts=prompts,
completions=completions,
completion_ids=completion_ids_list,
**reward_kwargs,
)
results.append(
[float(r) if r is not None else float("nan") for r in output]
)
conn.send(results)
except Exception:
conn.send(None)
# ---------------------------------------------------------------------------
# Extended trainer
# ---------------------------------------------------------------------------
class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
"""GRPOTrainer with experimental extensions.
Adds:
- Parallel reward subprocess workers (``reward_num_workers``)
- Replay buffer for high-signal group reuse (``replay_buffer_size``)
- Deferred re-roll of failed prompts (``reroll_start_fraction``)
- Zero-advantage micro-batch skipping
"""
def __init__(self, *args, **kwargs):
# These must be initialized before super().__init__() because
# _create_data_producer (called during super().__init__) needs them.
self._reroll_buffer: list = []
self._reroll_lock = threading.Lock()
# Temporarily suppress the base class's Liger + OPSM validation check,
# since this subclass supports it via a custom compute_liger_loss override.
grpo_args = kwargs.get("args")
if grpo_args is None:
for a in args:
if hasattr(a, "off_policy_mask_threshold"):
grpo_args = a
break
saved_threshold = None
if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False):
saved_threshold = grpo_args.off_policy_mask_threshold
grpo_args.off_policy_mask_threshold = None
super().__init__(*args, **kwargs)
if saved_threshold is not None:
grpo_args.off_policy_mask_threshold = saved_threshold
self.off_policy_mask_threshold = saved_threshold
# Replay buffer
if getattr(self.args, "replay_buffer_size", 0) > 0:
self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)
else:
self._replay_buffer = None
self._replay_recompute_logps = getattr(
self.args, "replay_recompute_logps", True
)
# Reward worker pool (lazy-initialized)
self._reward_workers = None
# -- Factory override: use RerollDataProducer ----------------------------
def _create_data_producer(self, args, train_dataset):
"""Override to use RerollDataProducer for re-roll prompt injection."""
from axolotl.core.trainers.grpo.async_trainer import (
AsyncDataProducer,
ProducerConfig,
)
producer_config = ProducerConfig(
mini_epochs=args.num_iterations,
max_rollouts=None,
eval_during_produce=False,
empty_cache_before_produce=True,
empty_cache_after_produce=True,
async_prefetch=args.async_prefetch,
prefetch_depth=args.prefetch_depth,
)
data_producer = RerollDataProducer(
config=producer_config,
prompt_dataset=train_dataset,
num_generations=self.num_generations,
generation_batch_size=args.generation_batch_size,
train_batch_size=args.per_device_train_batch_size,
steps_per_generation=args.steps_per_generation,
shuffle_dataset=self.shuffle_dataset,
seed=args.seed,
)
data_producer.set_trainer(self)
if args.async_prefetch:
data_producer = AsyncDataProducer(
data_producer,
background_produce_kwargs={"skip_policy_logps": True},
)
return data_producer
# -- Reward worker pool --------------------------------------------------
def _get_reward_workers(self):
"""Return a list of persistent reward worker subprocesses (lazy-initialized)."""
import multiprocessing as _mp
num_workers = getattr(self.args, "reward_num_workers", 1)
if num_workers < 1:
num_workers = 1
if self._reward_workers is not None:
alive = all(proc.is_alive() for conn, proc in self._reward_workers)
if alive and len(self._reward_workers) == num_workers:
return self._reward_workers
self._shutdown_reward_workers()
workers = []
for _ in range(num_workers):
parent_conn, child_conn = _mp.Pipe()
proc = _mp.Process(
target=_persistent_reward_worker, args=(child_conn,), daemon=True
)
proc.start()
child_conn.close()
workers.append((parent_conn, proc))
self._reward_workers = workers
return workers
def _shutdown_reward_workers(self):
"""Shut down all persistent reward workers."""
if self._reward_workers is None:
return
for conn, proc in self._reward_workers:
try:
conn.send(None)
proc.join(timeout=5)
except Exception:
pass
try:
conn.close()
except Exception:
pass
self._reward_workers = None
# -- Hook overrides ------------------------------------------------------
def _compute_rewards_for_batch(
self, inputs, prompts, completions, completion_ids_list
):
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
return self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list
)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Send reward work to subprocess workers (non-blocking).
Results are collected later by _collect_reward_workers, allowing GPU
logprob computation to overlap with CPU reward computation.
"""
reward_can_bg = all(
callable(rf)
and not isinstance(rf, nn.Module)
and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs
)
num_workers = getattr(self.args, "reward_num_workers", 1)
if not reward_can_bg or num_workers <= 1:
# Can't parallelize — store args for sync fallback in collect
self._reward_workers_used = None
self._pending_reward_args = (
inputs,
prompts,
completions,
completion_ids_list,
)
return
workers = self._get_reward_workers()
num_generations = self.num_generations
num_prompts = len(prompts)
num_groups = num_prompts // num_generations
# Shard by prompt groups across workers
groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
workers_used = []
for w_idx, (conn, _proc) in enumerate(workers):
g_start = w_idx * groups_per_worker
g_end = min((w_idx + 1) * groups_per_worker, num_groups)
if g_start >= num_groups:
break
s_start = g_start * num_generations
s_end = g_end * num_generations
conn.send(
(
self.reward_funcs,
prompts[s_start:s_end],
completions[s_start:s_end],
completion_ids_list[s_start:s_end],
inputs[s_start:s_end],
self.reward_func_names,
)
)
workers_used.append(conn)
self._reward_workers_used = workers_used
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
def _collect_reward_workers(
self, inputs, prompts, completions, completion_ids_list
):
"""Collect reward results from subprocess workers (blocks until done)."""
from accelerate.utils import gather
workers_used = getattr(self, "_reward_workers_used", None)
args = getattr(self, "_pending_reward_args", None)
self._reward_workers_used = None
self._pending_reward_args = None
if workers_used is None:
# Sync fallback — compute on main thread
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
device = self.accelerator.device
num_prompts = len(args[1]) if args else len(prompts)
# Collect results from workers
all_worker_results = []
any_failed = False
for conn in workers_used:
result = conn.recv()
if result is None:
any_failed = True
# Drain remaining workers to prevent stale results in pipes
for remaining_conn in workers_used:
if remaining_conn is not conn:
try:
remaining_conn.recv()
except Exception:
pass
break
all_worker_results.append(result)
if not any_failed:
rewards_per_func = torch.zeros(
num_prompts, len(self.reward_funcs), device=device
)
offset = 0
for worker_result in all_worker_results:
chunk_size = len(worker_result[0])
for i, result in enumerate(worker_result):
rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
result, dtype=torch.float32, device=device
)
offset += chunk_size
return gather(rewards_per_func)
# Fallback to main thread on failure
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
def _post_advantage_hook(
self,
data: dict,
rewards_per_func,
advantages,
inputs: list,
num_generations: int,
mode: str,
s_start: int | None = None,
s_end: int | None = None,
is_last_chunk: bool = True,
) -> None:
"""Replay buffer store/replace + re-roll buffering."""
from trl.models.utils import disable_gradient_checkpointing
# -- Replay buffer: store high-signal groups --
if self._replay_buffer is not None:
local_grouped = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = local_grouped.std(dim=1)
has_signal = (per_group_std > 0).any(dim=1)
offset = s_start or 0
if has_signal.any():
grouped_adv = advantages.view(-1, num_generations)
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
group_data = {}
for key in data:
val = data[key]
if (
isinstance(val, torch.Tensor)
and val.dim() > 0
and val.size(0) >= end
):
group_data[key] = val[start:end].clone()
self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups with high-signal replay buffer entries
# Only in non-streaming path (s_start is None) — streaming scores
# groups incrementally, so replacement + logprob recompute would be
# too expensive per chunk.
n_replaced = 0
if s_start is None:
no_signal = ~has_signal
replaced_ranges = []
if no_signal.any() and len(self._replay_buffer) > 0:
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
sampled = self._replay_buffer.sample(1)
if sampled is None:
break
sampled_group = sampled[0]
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
for key, val in sampled_group.items():
if key in data and isinstance(data[key], torch.Tensor):
src = val.to(data[key].device)
tgt_seq_len = (
data[key].size(1) if data[key].dim() > 1 else None
)
if start >= data[key].size(0) or end > data[key].size(
0
):
continue
if tgt_seq_len is not None:
if src.size(1) <= tgt_seq_len:
data[key][start:end] = 0
data[key][start:end, : src.size(1)] = src
else:
data[key][start:end] = src[:, :tgt_seq_len]
else:
data[key][start:end] = src
replaced_ranges.append((start, end))
n_replaced += 1
# Recompute old_per_token_logps for replayed groups
if (
n_replaced > 0
and self._replay_recompute_logps
and "old_per_token_logps" in data
):
with (
torch.no_grad(),
disable_gradient_checkpointing(
self.model, self.args.gradient_checkpointing_kwargs
),
):
for r_start, r_end in replaced_ranges:
r_ids = torch.cat(
[
data["prompt_ids"][r_start:r_end],
data["completion_ids"][r_start:r_end],
],
dim=1,
)
r_mask = torch.cat(
[
data["prompt_mask"][r_start:r_end],
data["completion_mask"][r_start:r_end],
],
dim=1,
)
r_logits_to_keep = data["completion_ids"].size(1)
r_fwd_kwargs = {}
for fk in (
"pixel_values",
"image_grid_thw",
"pixel_attention_mask",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
if fk in data:
r_fwd_kwargs[fk] = data[fk]
r_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
r_ids,
r_mask,
r_logits_to_keep,
r_end - r_start,
**r_fwd_kwargs,
)
data["old_per_token_logps"][r_start:r_end] = r_logps
if n_replaced > 0:
self._metrics[mode]["replay_buffer_replacements"].append(
float(n_replaced)
)
if is_last_chunk:
self._metrics[mode]["replay_buffer_size"].append(
float(len(self._replay_buffer))
)
# -- Re-roll buffer: store failed prompts --
if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0:
grouped_rewards = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = grouped_rewards.std(dim=1)
per_group_mean = grouped_rewards.mean(dim=1)
zero_signal = (per_group_std == 0).all(dim=1)
all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)
should_reroll = zero_signal & all_failed
_n_buffered = 0
with self._reroll_lock:
for group_idx in should_reroll.nonzero(as_tuple=True)[0]:
idx = group_idx.item() * num_generations
if idx >= len(inputs):
continue
prompt_input = inputs[idx]
self._reroll_buffer.append(prompt_input)
_n_buffered += 1
if _n_buffered > 0:
self._metrics[mode]["reroll_buffered"].append(float(_n_buffered))
if is_last_chunk:
self._metrics[mode]["reroll_buffer_size"].append(
float(len(self._reroll_buffer))
)
# -- Zero-advantage skipping + Liger OPSM ---------------------------------
def compute_liger_loss(self, unwrapped_model, inputs):
"""Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).
The base class Liger path doesn't support OPSM because the fused kernel
doesn't expose per-token logprobs needed for the KL computation. This
override computes them via chunked lm_head matmul (no grad, low memory)
and applies the OPSM to the loss mask before calling the kernel.
"""
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
return torch.tensor(
0.0, device=inputs["advantages"].device, requires_grad=True
)
if self.off_policy_mask_threshold is None:
return super().compute_liger_loss(unwrapped_model, inputs)
# OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
last_hidden_state = self._get_last_hidden_state(
unwrapped_model,
input_ids,
attention_mask,
logits_to_keep,
inputs.get("pixel_values"),
inputs.get("image_grid_thw"),
inputs.get("pixel_attention_mask"),
inputs.get("image_sizes"),
)
loss_mask = (
completion_mask
if "tool_mask" not in inputs
else completion_mask * inputs["tool_mask"]
)
# Compute per_token_logps via chunked lm_head matmul (no grad, low memory)
lm_weight = unwrapped_model.lm_head.weight
lm_bias = unwrapped_model.lm_head.bias
with torch.no_grad():
per_token_logps_chunks = []
for i in range(last_hidden_state.size(0)):
chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
if lm_bias is not None:
chunk_logits = chunk_logits + lm_bias
chunk_lps = (
chunk_logits.float()
.log_softmax(-1)
.gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
.squeeze(-1)
)
per_token_logps_chunks.append(chunk_lps)
del chunk_logits
per_token_logps = torch.cat(per_token_logps_chunks, dim=0)
advantages = inputs["advantages"]
if advantages.dim() == 1:
advantages_2d = advantages.unsqueeze(1)
else:
advantages_2d = advantages
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = inputs.get("old_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = per_token_logps
off_policy_mask = GRPOTrainer.get_off_policy_mask(
advantages=advantages_2d,
per_token_logps=per_token_logps,
sampling_per_token_logps=sampling_per_token_logps,
mask=loss_mask,
off_policy_threshold=self.off_policy_mask_threshold,
)
loss_mask = loss_mask * off_policy_mask
# Call the Liger fused kernel with OPSM-modified mask
loss, metrics = self.liger_grpo_loss(
_input=last_hidden_state,
lin_weight=unwrapped_model.lm_head.weight,
selected_token_ids=completion_ids,
attention_mask=loss_mask,
advantages=inputs["advantages"],
bias=unwrapped_model.lm_head.bias,
old_per_token_logps=inputs.get("old_per_token_logps"),
ref_per_token_logps=inputs.get("ref_per_token_logps"),
vllm_is_ratio=inputs.get("importance_sampling_ratio"),
)
mean_kl = metrics[0] if self.beta != 0.0 else None
clip_ratio = metrics[-1]
mode = "train" if self.model.training else "eval"
if self.beta != 0.0:
self._metrics[mode]["kl"].append(
self.accelerator.gather(mean_kl).mean().item()
)
self._metrics[mode]["clip_ratio"].append(
self.accelerator.gather(clip_ratio).mean().item()
)
normalizer = (
self.current_gradient_accumulation_steps if mode == "train" else 1.0
)
return loss / normalizer
def _compute_loss(self, model, inputs):
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
# Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
# With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
# so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
# with a single token to create a proper computation graph.
prompt_ids = inputs["prompt_ids"][:1, :1] # (1, 1)
attn = torch.ones_like(prompt_ids)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
out = model(input_ids=prompt_ids, attention_mask=attn)
return out.logits.sum() * 0
return super()._compute_loss(model, inputs)

View File

@@ -0,0 +1,44 @@
"""Simple replay buffer for storing and sampling high-signal rollout groups."""
import heapq
import torch
class ReplayBuffer:
"""Min-heap replay buffer that keeps the highest-scoring rollout groups.
Groups are scored by signal quality (advantage magnitude * reward variance).
When sampling, groups are drawn proportional to their scores.
"""
def __init__(self, max_size: int):
self.max_size = max_size
self._heap: list[tuple[float, int, dict]] = [] # min-heap of (score, id, data)
self._counter = 0 # unique tiebreaker for heap
def __len__(self):
return len(self._heap)
def add(self, score: float, data: dict):
"""Add a group to the buffer. If full, replaces lowest-scoring entry."""
if self.max_size <= 0:
return
self._counter += 1
if len(self._heap) < self.max_size:
heapq.heappush(self._heap, (score, self._counter, data))
elif score > self._heap[0][0]:
heapq.heapreplace(self._heap, (score, self._counter, data))
def sample(self, num_samples: int) -> list[dict] | None:
"""Sample groups weighted by their scores. Returns None if buffer is empty."""
if self.max_size <= 0 or not self._heap:
return None
scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)
scores = scores.clamp(min=1e-8) # avoid zero probabilities
probs = scores / scores.sum()
replacement = num_samples > len(self._heap)
indices = torch.multinomial(
probs, num_samples, replacement=replacement
).tolist()
return [self._heap[i][2] for i in indices]

View File

@@ -40,6 +40,7 @@ from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
@@ -66,6 +67,19 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
class AxolotlAsyncGRPOTrainer(
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
FastAsyncGRPOTrainer,
):
"""Extend AsyncGRPOTrainer with axolotl helpers"""
_tag_names = ["trl", "grpo", "async", "axolotl"]
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -19,5 +19,4 @@ class CheckpointSaveMixin(Trainer):
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible.",
main_process_only=True,
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
```
## Usage
@@ -73,8 +73,10 @@ plugins:
- ministral3
- mistral
- mistral3
- mistral4
- mixtral
- mllama
- nemotron_h
- olmo
- olmo2
- olmo3

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
)

View File

@@ -0,0 +1,120 @@
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""
import logging
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
LOG = logging.getLogger(__name__)
# Give up looking for autotune data after this many training steps.
_MAX_POLL_STEP = 5
def _get_gpu_info() -> dict:
"""Return basic GPU identification for the current device."""
if not torch.cuda.is_available():
return {}
try:
idx = torch.cuda.current_device()
props = torch.cuda.get_device_properties(idx)
return {
"gpu_name": props.name,
"gpu_compute_capability": f"{props.major}.{props.minor}",
"gpu_memory_bytes": props.total_memory,
}
except Exception: # pylint: disable=broad-exception-caught
return {}
def _get_smem_capacity() -> dict:
"""Return shared memory capacity from the runtime lora_ops module."""
try:
from axolotl.integrations.kernels.autotune_collector import (
_find_lora_ops_module,
)
lora_ops = _find_lora_ops_module()
if lora_ops is None:
return {}
fn = getattr(lora_ops, "_get_smem_capacity", None)
if fn is None:
return {}
return {"smem_capacity_bytes": fn()}
except Exception: # pylint: disable=broad-exception-caught
return {}
class AutotuneReportCallback(TrainerCallback):
"""Reports Triton kernel autotune selections via telemetry.
Fires **once** after the first training step completes (step 1), at
which point the forward and backward passes have both run and the
autotuned kernels have populated their caches. If for some reason
the caches are still empty (e.g. the kernel was never invoked), the
callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
then stops polling.
After reporting (or giving up) every subsequent ``on_step_end``
call short-circuits on the ``_reported`` flag — zero hot-path cost.
"""
def __init__(self):
self._reported = False
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._reported:
return
# Lazy import — Triton / scattermoe kernels may not be installed.
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
configs = collect_autotune_configs()
if not configs:
if state.global_step >= _MAX_POLL_STEP:
LOG.debug(
"No autotune data found after %d steps; giving up.",
state.global_step,
)
self._reported = True
return
self._reported = True
from axolotl.telemetry.manager import TelemetryManager
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return
properties = {
"kernel_count": len(configs),
"kernels": configs,
}
properties.update(_get_gpu_info())
properties.update(_get_smem_capacity())
telemetry_manager.send_event(
event_type="scattermoe-autotune",
properties=properties,
)
LOG.info(
"Reported %d scattermoe kernel autotune config(s) to telemetry.",
len(configs),
)
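For orientation, here is a sketch of the `properties` payload this callback assembles; the key names follow the code above, and every value is invented for illustration:

```python
# Hypothetical payload shape; values are illustrative, not from a real run.
properties = {
    "kernel_count": 1,
    "kernels": [...],  # entries from collect_autotune_configs(); see next file
    "gpu_name": "NVIDIA H100 80GB HBM3",
    "gpu_compute_capability": "9.0",
    "gpu_memory_bytes": 85_520_809_984,
    "smem_capacity_bytes": 232_448,
}
```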

View File

@@ -0,0 +1,114 @@
"""Collect Triton autotune results from scattermoe-lora kernels.
This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
decorated kernel objects and returns structured dicts describing the selected
configurations. It has **no** telemetry dependency — callers decide what to
do with the data.
"""
import logging
import sys
from types import ModuleType
from typing import Any
LOG = logging.getLogger(__name__)
# (human-readable name, attribute on the lora_ops module)
_KERNEL_REGISTRY: list[tuple[str, str]] = [
("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
("group_bwd_lora", "_group_bwd_lora"),
("group_bwd_lora_fused", "_group_bwd_lora_fused"),
]
# The autotune key declared on every kernel: key=["M", "N", "K"]
_KEY_NAMES: list[str] = ["M", "N", "K"]
def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
"""Turn the autotune cache key tuple into a labelled dict.
Triton builds the cache key from the values of the declared ``key``
args (``M``, ``N``, ``K``) followed by dtype signature elements.
We label the first three and store the rest under ``_extra``.
"""
result: dict[str, Any] = {}
for i, name in enumerate(_KEY_NAMES):
if i < len(key_tuple):
result[name] = key_tuple[i]
if len(key_tuple) > len(_KEY_NAMES):
result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
return result
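# Illustrative (hedged): a cache key like (4096, 2816, 2048, "torch.bfloat16")
# would become {"M": 4096, "N": 2816, "K": 2048, "_extra": ["torch.bfloat16"]}.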
def _find_lora_ops_module() -> ModuleType | None:
"""Locate the *runtime* ``lora_ops`` module in ``sys.modules``.
The HF ``kernels`` package loads ``scattermoe_lora`` via
``import_from_path`` which registers it in ``sys.modules`` under a
hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
import (``from axolotl.integrations.kernels...``) would create a
*separate* module instance whose kernel objects have empty
``.cache`` dicts because autotuning ran on the runtime copy.
We search ``sys.modules`` for any module whose name contains
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
attribute — that is the runtime copy with populated caches.
"""
for name, module in list(sys.modules.items()):
if (
module is not None
and "lora_ops" in name
and hasattr(module, "_scatter2scatter_lora")
):
return module
return None
def collect_autotune_configs() -> list[dict[str, Any]]:
"""Read autotune caches from the four scattermoe-lora kernels.
Returns a (possibly empty) list of dicts, each containing:
* ``kernel`` human-readable kernel name
* ``key`` dict with the ``M``/``N``/``K`` problem dimensions
* ``config`` dict with the selected tile sizes, ``num_warps``,
and ``num_stages``
Returns ``[]`` if the kernel module cannot be found or if no
autotune cache entries exist yet.
"""
lora_ops = _find_lora_ops_module()
if lora_ops is None:
LOG.debug(
"lora_ops module not found in sys.modules; skipping autotune collection"
)
return []
results: list[dict[str, Any]] = []
for friendly_name, attr_name in _KERNEL_REGISTRY:
kernel_fn = getattr(lora_ops, attr_name, None)
if kernel_fn is None:
continue
cache = getattr(kernel_fn, "cache", None)
if not cache:
continue
for key_tuple, config in cache.items():
config_dict = dict(config.kwargs)
config_dict["num_warps"] = config.num_warps
config_dict["num_stages"] = config.num_stages
if getattr(config, "num_ctas", None) is not None:
config_dict["num_ctas"] = config.num_ctas
results.append(
{
"kernel": friendly_name,
"key": _parse_key_tuple(key_tuple),
"config": config_dict,
}
)
return results
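A minimal usage sketch, assuming the kernels have already run at least one forward/backward pass so their autotune caches are populated (entry values are hypothetical):

```python
from axolotl.integrations.kernels.autotune_collector import collect_autotune_configs

for entry in collect_autotune_configs():
    # entry shape (values hypothetical):
    # {"kernel": "scatter2scatter_lora_fwd",
    #  "key": {"M": 4096, "N": 2816, "K": 2048},
    #  "config": {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64,
    #             "num_warps": 8, "num_stages": 4}}
    print(entry["kernel"], entry["key"], entry["config"])
```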

View File

@@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = {
"olmoe": "OlmoeSparseMoeBlock",
"mixtral": "MixtralSparseMoeBlock",
"minimax": "MiniMaxSparseMoeBlock",
# softmax -> topk routing (with group-based expert selection)
"mistral4": "Mistral4MoE",
# sigmoid -> topk routing (with group-based expert selection)
"glm_moe_dsa": "GlmMoeDsaMoE",
"deepseek_v3": "DeepseekV3MoE",

View File

@@ -195,6 +195,36 @@ def _estimate_smem_usage(
_SMEM_SLACK = 10_000
def _estimate_register_pressure(
num_warps: int,
*tile_sizes: tuple[int, int],
) -> float:
"""Rough estimate of per-thread register footprint from live tile sizes.
This is a heuristic, NOT an accurate register count. Triton uses tensor
core MMA fragments that pack multiple elements per register, and can spill
to local memory when the hardware limit (255 regs/thread) is exceeded.
The estimate is used to prune only truly extreme configs that would cause
excessive spilling or compilation failures. The threshold is set high
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
(est ~520) work fine in practice via spilling.
Returns estimated registers per thread.
"""
# Each thread in a warp holds ~1/32 of the tile elements
tile_regs = sum(r * c for r, c in tile_sizes) / 32
scalar_overhead = 40
return tile_regs + scalar_overhead
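# Worked check of the docstring's M=N=K=64 example (assuming BLOCK_R=16):
# acc[64,64] + x[64,64] + w[64,64] + xa_acc[64,16] + a[16,64] + b[64,16]
# = 15,360 tile elements -> 15,360 / 32 + 40 = 520 estimated regs/thread.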
# Soft limit for register pressure pruning. Only prune configs with extreme
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
_MAX_REGS_SOFT_LIMIT = 1024
# =============================================================================
# Forward Kernel: scatter2scatter with fused LoRA
# =============================================================================
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
) # [BLOCK_N, BLOCK_R]
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
# Both operands must match; cast to float32 (accumulator type) for precision.
b_f32 = b.to(tl.float32)
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
b_inp = b.to(INPUT_DTYPE)
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
acc += scaling * lora_out
return acc
@@ -327,20 +356,21 @@ def _compute_expert_block_lora(
def _scatter2scatter_lora_configs():
"""Generate forward kernel autotune configs.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space includes BLOCK_M to allow trading token-tile size for
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
(4× fewer inner-loop iterations).
Search space:
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64, 128, 256}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
scatter2scatter pattern).
"""
configs = []
for block_n, block_k, warps, stages in product(
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
@@ -348,7 +378,7 @@ def _scatter2scatter_lora_configs():
):
configs.append(
triton.Config(
{"BLOCK_N": block_n, "BLOCK_K": block_k},
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
num_stages=stages,
num_warps=warps,
)
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
def _prune_fwd_configs(configs, named_args, **kwargs):
"""Prune forward configs based on SMEM capacity.
"""Prune forward configs based on SMEM capacity and register pressure.
The forward kernel inner loop loads three tiles per pipeline stage:
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_n = config.kwargs["BLOCK_N"]
block_k = config.kwargs["BLOCK_K"]
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_r * block_k * 2
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
smem_lora_epilogue = block_n * block_r * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_n), # acc
(block_m, block_r), # xa_acc
(block_m, block_k), # x tile
(block_k, block_n), # w tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
),
)
]
@triton.autotune(
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def _scatter2scatter_lora_split(
X: torch.Tensor,
W: torch.Tensor,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
k: int,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
b: Optional[torch.Tensor] = None,
x_grouped: bool = False,
y_grouped: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
because the base kernel runs at full speed without LoRA SMEM overhead,
and the LoRA matmuls (R=16) are tiny separate passes.
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
"""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
scatter2scatter,
)
E = W.size(0)
R = lora_A.size(0) // E
K = W.size(1)
N = W.size(2)
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
output = scatter2scatter(
X=X,
W=W,
b=b,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=y_grouped,
out=out,
)
# 2. XA = X @ A^T (tiny: output is [M*k, R])
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
XA = scatter2scatter(
X=X,
W=W_A,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=True,
)
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
# Reshape B: [N, R*E] → [E, R, N]
W_B = lora_B.T.reshape(E, R, N).contiguous()
Y_lora = scatter2scatter(
X=XA,
W=W_B,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
x_grouped=True,
y_grouped=y_grouped,
)
# 4. Y = Y_base + scaling * Y_lora
output.add_(Y_lora, alpha=scaling)
return output
# Threshold for switching from fused to split LoRA forward.
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
# loads dominate in the fused kernel's inner loop).
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
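# Worked check (dims assumed): Mixtral-style experts with E=8, hidden K=4096,
# intermediate N=14336 give K*N ~= 58.7M >= 20M with E <= 32, so the split
# path is used; a 64-expert model with K*N around 4M stays on the fused kernel.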
def scatter2scatter_lora(
X: torch.Tensor,
W: torch.Tensor,
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Automatically selects between:
- Fused kernel: single Triton kernel with LoRA in the inner loop.
Best for many small experts (E>=64, small K*N).
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
Best for few large experts (E<=32, large K*N like Mixtral).
Args:
X: Input [M, K] or [M*k, K] if x_grouped
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
Returns:
Y: Output [M*k, N]
"""
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
E = W.size(0)
K = W.size(1)
N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
return _scatter2scatter_lora_split(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
lora_A,
lora_B,
scaling,
b,
x_grouped,
y_grouped,
out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
R = lora_A.size(0) // E
# Pad R to power of 2 for Triton tile size
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
b_ptr,
stride_be,
stride_bn,
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
lora_A,
lora_A.stride(0),
lora_A.stride(1),
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
lora_B,
lora_B.stride(0),
lora_B.stride(1),
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
K=K,
N=N,
E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
allow_tf32=ALLOW_TF32,
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
+ (A_expert_offset + R_block)[:, None] * stride_ar
+ K_block[None, :] * stride_ak
)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
# Cast to float32 for precision
a_f32 = a_e.to(tl.float32)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
INPUT_DTYPE
)
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
acc += scaling * lora_dx
return acc
@@ -779,17 +939,18 @@ def _scatter2scatter_lora_dX_configs():
The inner loop is over N (not K as in forward). The output dimension is K.
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
BLOCK_M is now autotunable (was fixed at 128).
Search space:
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_k, block_n, warps, stages in product(
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
@@ -797,7 +958,7 @@ def _scatter2scatter_lora_dX_configs():
):
configs.append(
triton.Config(
{"BLOCK_K": block_k, "BLOCK_N": block_n},
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
num_stages=stages,
num_warps=warps,
)
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
def _prune_dX_configs(configs, named_args, **kwargs):
"""Prune backward dX configs based on SMEM capacity.
"""Prune backward dX configs based on SMEM capacity and register pressure.
The dX kernel inner loop loads three tiles per pipeline stage:
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_k = config.kwargs["BLOCK_K"]
block_n = config.kwargs["BLOCK_N"]
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_n * block_r * 2
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
smem_lora_epilogue = block_r * block_k * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_k), # acc
(block_m, block_r), # dy_b_acc
(block_m, block_n), # dy tile
(block_n, block_k), # wt tile
(block_n, block_r), # b tile
(block_r, block_k), # a tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
N=N,
E=E,
ACTUAL_R=R,
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
"""Prune backward configs based on SMEM capacity.
"""Prune backward configs based on SMEM capacity and register pressure.
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
smem_lora = (block_r * block_k + block_n * block_r) * 2
smem = smem_base + smem_lora
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_k), # dA_acc
(block_n, block_r), # dB_acc
(block_m, block_k), # x tile
(block_m, block_n), # dy tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile
(block_m, block_r), # xa intermediate
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
)
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity and register pressure."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weight tile ([INNER, R] or [R, DIM]) held for the full expert, counted once
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
# Register pressure check
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_dim), # acc
(block_m, BLOCK_INNER), # input tile
(block_m, block_dim), # other tile
(block_r, BLOCK_INNER), # lora weight
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
{
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
}
)
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr,
stride_dym,
stride_dyn,
X_ptr,
stride_xm,
stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr,
stride_lw0,
stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr,
stride_out0,
stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M,
K: tl.constexpr,
N: tl.constexpr,
ACTUAL_R: tl.constexpr,
BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
else:
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr
+ inn_off[:, None] * stride_lw0
+ (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr
+ M_idx[:, None] * stride_xm
+ inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr
+ (lora_offset + R_block)[None, :] * stride_lw0
+ inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
)
tl.store(
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
else:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
def group_bwd_lora(
DY: torch.Tensor,
X: torch.Tensor,
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
"""
Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert)
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
K = X.size(1)
N = DY.size(1)
# Zero-init for atomic accumulation
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
# No zero-init needed: the split kernels write zeros for experts with
# zero routed tokens directly in the kernel (else branch).
dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R)
def grid(META):
return (
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
def grid_dA(META):
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora[grid](
_group_bwd_lora_split[grid_dA](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
COMPUTE_DA=True,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY,
DY.stride(0),
DY.stride(1),
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
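As a sanity reference for what the two launches above compute, a minimal pure-PyTorch sketch of the per-expert gradient math from the kernel docstring (shapes and names assumed here, not the kernel API):

```python
import torch

# One expert's slice; follows the docstring math:
#   dA[e] = scaling * (dY @ B[e]).T @ X     -> [R, K]
#   dB[e] = scaling * dY.T @ (X @ A[e].T)   -> [N, R]
M, K, N, R, scaling = 128, 64, 96, 16, 2.0
X = torch.randn(M, K)    # expert-grouped input rows routed to expert e
dY = torch.randn(M, N)   # matching output gradients
A = torch.randn(R, K)    # LoRA A slice for expert e
B = torch.randn(N, R)    # LoRA B slice for expert e

dA = scaling * (dY @ B).T @ X     # reduce over N, then over tokens M
dB = scaling * dY.T @ (X @ A.T)   # reduce over K, then over tokens M
```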

View File

@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
return base_experts, gup_lora, down_lora
# =============================================================================
# Routing helpers
# =============================================================================
def _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return routing_weights, selected_experts, top_k, num_experts
def _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).
Supports:
- ``e_score_correction_bias`` on gate or moe_block
- Group-based expert selection when ``n_group > 1``
- ``routed_scaling_factor`` applied to final weights
- Final weights gathered from original sigmoid probs (not bias-corrected)
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states.float(), gate_weight.float())
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = router_logits.sigmoid() # [T, E]
top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is not None:
scores_for_choice = router_probs + e_score_correction_bias
else:
scores_for_choice = router_probs
# Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1)
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, num_experts // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [T, n_group]
topk_group = getattr(moe_block, "topk_group", n_group)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(-1, n_group, num_experts // n_group)
.reshape(-1, num_experts)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_indices, top_k, num_experts
def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
"""Dispatch to the correct routing strategy based on block attributes.
Detects sigmoid routing by the presence of ``e_score_correction_bias``
on either the gate or the moe_block.
"""
has_sigmoid = (
getattr(base_gate, "e_score_correction_bias", None) is not None
or getattr(moe_block, "e_score_correction_bias", None) is not None
)
if has_sigmoid:
return _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
return _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
# =============================================================================
# Shared expert helpers
# =============================================================================
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
peft wraps individual linear layers inside the shared expert with
standard LoRA — calling forward() handles this transparently.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is None:
return None
shared_expert_output = shared_expert(hidden_states_flat)
# Optional sigmoid gate (Qwen2MoE pattern).
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear); its forward() applies LoRA automatically.
shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
if shared_expert_gate is not None:
shared_expert_output = (
F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
)
return shared_expert_output
# =============================================================================
# Layer classes
# =============================================================================
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):
class HFScatterMoEGatedMLP(nn.Module):
"""
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
ScatterMoE-accelerated forward pass for HF MoEs.
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
method replaces the original ``OlmoeSparseMoeBlock.forward``.
method replaces the original SparseMoeBlock.forward.
Supports both full-parameter training and LoRA fine-tuning:
Supports:
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
- self.shared_expert(s): Optional shared expert
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE)
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
# ====================================================================
# peft wraps individual linear layers inside shared_expert with
# standard LoRA — calling forward() handles this transparently.
if hasattr(self, "shared_expert") and self.shared_expert is not None:
shared_expert_output = self.shared_expert(hidden_states_flat)
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate_output = F.sigmoid(
self.shared_expert_gate(hidden_states_flat)
)
shared_expert_output = shared_expert_output * shared_expert_gate_output
else:
shared_expert_output = None
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
router_logits = F.linear(hidden_states_flat, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states_flat, gate_lora_delta
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if base_gate.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights, selected_experts, top_k, num_experts = _route(
self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
@@ -356,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
# When experts are BnB-quantized (quantize_moe_experts), dequantize
# only the active experts instead of all E. This saves ~97% memory
# for the transient dequant buffer when few experts are active.
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
get_active_experts,
remap_expert_indices,
selective_expert_weights,
selective_lora_weights,
)
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs,
expert_offsets,
active_experts,
num_experts,
)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts,
"gate_up_proj",
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A,
gup_B,
active_experts,
num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
# Use remapped indices for ScatterMoE kernels
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
@@ -383,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
grouped_in=False,
grouped_out=True,
)
@@ -396,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if use_selective:
down_W = selective_expert_weights(
experts,
"down_proj",
active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A,
down_B,
active_experts,
num_experts,
)
down_lora = (down_A, down_B, down_scaling)
sei_down = remapped_expert_idxs
eo_down = compact_offsets
else:
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
sei_down = sorted_expert_idxs
eo_down = expert_offsets
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
@@ -404,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
@@ -421,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
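To make the group-based selection in `_sigmoid_topk_route` above concrete, a toy walk-through in plain PyTorch (every size and probability invented):

```python
import torch

# Toy sizes: 1 token, E=8 experts in n_group=4 groups of 2,
# keep topk_group=2 groups, then pick top_k=2 experts.
probs = torch.tensor([[0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.4, 0.05]])
n_group, E, topk_group, top_k = 4, 8, 2, 2
# Group score = sum of each group's top-2 members: [1.0, 0.5, 1.5, 0.45]
group_scores = probs.view(-1, n_group, E // n_group).topk(2, dim=-1)[0].sum(-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]  # groups 2 and 0
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
masked = probs.masked_fill(~score_mask.bool(), 0.0)   # experts 0, 1, 4, 5 survive
topk_indices = torch.topk(masked, k=top_k, dim=-1)[1]  # experts 0 and 4
topk_weights = probs.gather(1, topk_indices)           # [[0.9, 0.8]], from raw probs
```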

View File

@@ -0,0 +1,282 @@
"""
Selective Expert Dequantization
===============================
Instead of dequantizing all E expert weight matrices at once (which creates
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
are actually routed to by the current batch's top-k selection.
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
- Savings: ~97% memory reduction per layer
This module provides format-agnostic selective weight extraction:
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
- bf16/fp32: direct indexing (no dequant needed)
- FP8: slice + cast
The ScatterMoE kernel itself doesn't change — we remap expert indices
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""
import torch
import torch.nn as nn
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
"""Get sorted unique expert indices from the routing output.
Args:
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
E: Total number of experts
Returns:
active: Sorted unique expert indices [num_active]
"""
return torch.unique(sorted_expert_idxs)
def remap_expert_indices(
sorted_expert_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Remap global expert indices to compact indices.
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
sort order. Also compacts expert_offsets to only active experts.
Args:
sorted_expert_idxs: [T*k] expert ids in sorted order
expert_offsets: [E] cumulative token counts (original)
active_experts: [num_active] sorted unique expert ids
E: Total number of experts
Returns:
remapped_idxs: [T*k] expert ids in [0..num_active-1]
compact_offsets: [num_active] cumulative token counts
"""
# Build remap table: global_id -> compact_id
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
remap[active_experts] = torch.arange(
len(active_experts), device=sorted_expert_idxs.device
)
remapped_idxs = remap[sorted_expert_idxs]
# Compact the expert_offsets: only keep active experts' cumulative counts
compact_offsets = expert_offsets[active_experts]
return remapped_idxs, compact_offsets
def _selective_dequant_bnb4(
raw_param: torch.Tensor,
quant_state,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
) -> torch.Tensor:
"""Dequantize only selected experts from BnB 4-bit packed data.
The raw parameter is a flattened 4-bit packed tensor. Each expert's
data is contiguous (stored in expert-major order), so we can gather
the packed data and absmax blocks for active experts, then dequantize
as one contiguous block.
Args:
raw_param: Flattened uint8 tensor of packed 4-bit weights
quant_state: BnB QuantState with absmax, blocksize, code, etc.
active_experts: [num_active] expert indices to dequantize
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
Returns:
Dequantized weights [num_active, dim1, dim2] in original dtype
"""
import bitsandbytes.functional as F # noqa: N812
from bitsandbytes.functional import QuantState
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
blocks_per_expert = expert_numel // quant_state.blocksize
num_active = len(active_experts)
if blocks_per_expert == 0:
# Expert is smaller than one quantization block — blocks span across
# expert boundaries, so per-expert slicing isn't possible.
# Fallback: full dequantize + index.
full = F.dequantize_4bit(raw_param, quant_state)
E_total = full.numel() // expert_numel
return full.reshape(E_total, *expert_shape)[active_experts]
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
selective_dequant_nf4_triton,
)
# Handle nested (double) quantization: dequantize absmax first
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
if quant_state.nested:
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
else:
absmax = quant_state.absmax
return selective_dequant_nf4_triton(
packed_data=raw_param,
absmax=absmax,
active_experts=active_experts,
expert_shape=expert_shape,
blocksize=quant_state.blocksize,
dtype=quant_state.dtype,
codebook=quant_state.code,
)
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
raw_flat = raw_param.reshape(-1)
offsets_qt = (
active_experts.long()[:, None] * packed_per_expert
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
qt_gathered = raw_flat[offsets_qt]
offsets_abs = (
active_experts.long()[:, None] * blocks_per_expert
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
if quant_state.nested:
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
full_absmax += quant_state.offset
if full_absmax.dtype != torch.float32:
full_absmax = full_absmax.float()
absmax_gathered = full_absmax[offsets_abs]
else:
absmax_gathered = quant_state.absmax[offsets_abs]
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
gathered_qs = QuantState(
absmax=absmax_gathered,
shape=torch.Size([num_active * expert_numel]),
blocksize=quant_state.blocksize,
quant_type=quant_state.quant_type,
code=quant_state.code,
dtype=quant_state.dtype,
)
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
return deq.reshape(num_active, *expert_shape)
def _selective_index_dense(
param: torch.Tensor,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Select experts from a dense (bf16/fp32) weight tensor.
Simple indexing — no dequantization needed.
"""
return param[active_experts]
def selective_expert_weights(
experts_module: nn.Module,
param_name: str,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Extract and dequantize only the active experts' weights.
Format-agnostic: dispatches based on whether the parameter is
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
Args:
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
param_name: "gate_up_proj" or "down_proj"
active_experts: [num_active] sorted unique expert indices
Returns:
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
"""
# Check if the parameter is BnB-quantized via parametrize
if (
hasattr(experts_module, "parametrizations")
and param_name in experts_module.parametrizations
):
param_list = experts_module.parametrizations[param_name]
parametrization = param_list[0]
# BnB 4-bit parametrization
if hasattr(parametrization, "quant_state"):
# The raw quantized data is on the ParametrizationList, not the
# individual Bnb4bitParametrization module
raw_param = param_list.original
qs = parametrization.quant_state
# qs.shape is the original tensor shape before flattening.
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
orig_shape = qs.shape
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
expert_shape = (orig_shape[1], orig_shape[2])
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
# Flattened — need to infer from module attributes
E_total = getattr(experts_module, "num_experts", None)
if E_total is None:
E_total = int(active_experts.max().item()) + 1
expert_numel = orig_shape[0] // E_total
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
experts_module, "intermediate_dim", None
)
if d2 and expert_numel % d2 == 0:
expert_shape = (expert_numel // d2, d2)
else:
full = getattr(experts_module, param_name)
return full[active_experts]
else:
full = getattr(experts_module, param_name)
return full[active_experts]
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
# Dense parameter (bf16/fp32) — direct indexing
param = getattr(experts_module, param_name)
if param.dim() == 3:
return param[active_experts]
# Fallback: full access
return param
def selective_lora_weights(
lora_A: torch.Tensor,
lora_B: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select LoRA A and B weights for only the active experts.
LoRA layout (scattermoe format):
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
Returns compact:
A: [r*num_active, K]
B: [N, r*num_active]
"""
R = lora_A.size(0) // E
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
row_idx = (
active_experts.long()[:, None] * R
+ torch.arange(R, device=lora_A.device)[None, :]
).reshape(-1)
compact_A = lora_A[row_idx] # [r*num_active, K]
compact_B = lora_B[:, row_idx] # [N, r*num_active]
return compact_A, compact_B
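A small end-to-end sketch of the remapping helpers on toy routing data (CPU tensors, sizes invented):

```python
import torch

from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
    remap_expert_indices,
    selective_lora_weights,
)

# Toy routing: E=4 experts, only experts 1 and 3 receive tokens.
sorted_expert_idxs = torch.tensor([1, 1, 3])
expert_offsets = torch.tensor([0, 2, 2, 3])  # cumulative token counts per expert
active = torch.unique(sorted_expert_idxs)    # tensor([1, 3])
remapped, compact = remap_expert_indices(sorted_expert_idxs, expert_offsets, active, 4)
# remapped -> tensor([0, 0, 1]); compact -> tensor([2, 3])

R, K, N = 2, 8, 8
lora_A, lora_B = torch.randn(R * 4, K), torch.randn(N, R * 4)
A_c, B_c = selective_lora_weights(lora_A, lora_B, active, 4)
# A_c: [R*2, K] rows for experts 1 and 3; B_c: [N, R*2] matching columns
```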

View File

@@ -0,0 +1,179 @@
"""
Triton kernel for fused selective expert gather + NF4 dequantization.
Instead of:
1. Gather packed uint8 data for active experts (memory copy)
2. Gather absmax for active experts (memory copy)
3. Call BnB dequantize_4bit CUDA kernel
This kernel does all three in one pass:
- Reads packed NF4 bytes from expert-strided positions
- Looks up the NF4 codebook
- Multiplies by the per-block absmax
- Writes bf16 output directly
This eliminates the intermediate gather buffer entirely.
"""
import torch
import triton
import triton.language as tl
# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
@triton.jit
def _selective_dequant_nf4_kernel(
# Input: packed NF4 data (flattened, expert-major order)
packed_ptr,
# Input: absmax values (flattened, expert-major order)
absmax_ptr,
# Input: active expert indices
active_experts_ptr,
# Input: NF4 codebook (16 float values)
codebook_ptr,
# Output: dequantized bf16 weights [num_active, expert_numel]
out_ptr,
stride_out_e, # stride for expert dim in output
# Dimensions
num_active,
packed_per_expert, # expert_numel // 2
blocks_per_expert, # expert_numel // blocksize
blocksize: tl.constexpr,
# Tile size
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
):
"""
Each program processes BLOCK_SIZE elements from one expert.
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
For each output element:
1. Compute which byte in packed data contains this element
2. Extract the 4-bit nibble (high or low)
3. Look up in NF4 codebook
4. Scale by absmax for this block
"""
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
block_id = tl.program_id(1) # which element block
# Load the global expert index
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = elem_offset < expert_numel
# Each element is packed as: byte[i//2], high nibble for even i, low for odd i
byte_idx = elem_offset // 2
is_high = (elem_offset % 2) == 1
# Read packed bytes from the global expert's region
packed_global_offset = expert_global * packed_per_expert + byte_idx
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
tl.int32
)
# Extract 4-bit nibble
# BnB packing: high nibble = even element, low nibble = odd element
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
# NF4 codebook lookup
# Gather the reconstruction value via the codebook pointer, indexed by nibble
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
# Load absmax for this element's quantization block
block_idx = elem_offset // blocksize
absmax_global_offset = expert_global * blocks_per_expert + block_idx
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
# Dequantize: value = codebook[nibble] * absmax
result = code_val * absmax_val
# Store to output
out_offset = expert_local_idx * stride_out_e + elem_offset
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
def selective_dequant_nf4_triton(
packed_data: torch.Tensor,
absmax: torch.Tensor,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
blocksize: int,
dtype: torch.dtype = torch.bfloat16,
codebook: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fused selective gather + NF4 dequantization via Triton kernel.
Args:
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
absmax: Per-block scaling factors [total_blocks]
active_experts: Sorted indices of experts to dequantize [num_active]
expert_shape: (dim1, dim2) per expert
blocksize: Quantization block size
dtype: Output dtype (default bf16)
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
Returns:
Dequantized weights [num_active, dim1, dim2]
"""
num_active = active_experts.shape[0]
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2
blocks_per_expert = expert_numel // blocksize
# Prepare codebook on device
if codebook is None:
codebook = torch.tensor(
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
)
else:
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
# Flatten inputs
packed_flat = packed_data.reshape(-1)
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
# Output buffer
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
_selective_dequant_nf4_kernel[grid](
packed_flat,
absmax_flat,
active_experts,
codebook,
out,
out.stride(0),
num_active=num_active,
packed_per_expert=packed_per_expert,
blocks_per_expert=blocks_per_expert,
blocksize=blocksize,
BLOCK_SIZE=BLOCK_SIZE,
)
return out.reshape(num_active, *expert_shape)
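A pure-PyTorch reference for the unpack, lookup, and scale steps, handy as a correctness cross-check for the kernel (toy data; the `absmax` value is invented):

```python
import torch

from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
    NF4_CODEBOOK,
)

# Two packed bytes hold four 4-bit elements; even elements live in the high
# nibble, odd elements in the low nibble (BnB convention, per the kernel).
packed = torch.tensor([0x1F, 0xA3], dtype=torch.uint8)
elem = torch.arange(packed.numel() * 2)
byte = packed[elem // 2].to(torch.int32)
is_odd = (elem % 2) == 1
nibble = torch.where(is_odd, byte & 0xF, (byte >> 4) & 0xF)  # [1, 15, 10, 3]
absmax = 0.5  # per-block scale (invented)
values = torch.tensor(NF4_CODEBOOK)[nibble] * absmax
```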

View File

@@ -61,9 +61,20 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
self._kernelize_model(moe_model_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
@@ -75,11 +86,9 @@ class KernelsPlugin(BasePlugin):
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
cfg.model_config_type,
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)
@@ -110,6 +119,16 @@ class KernelsPlugin(BasePlugin):
}
)
def add_callbacks_pre_trainer(self, cfg, model):
callbacks = []
if cfg.use_scattermoe:
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
callbacks.append(AutotuneReportCallback())
return callbacks
def _kernelize_model(self, model_type: str):
from kernels import replace_kernel_forward_from_hub

View File

@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
"minimax",
):
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("mistral4",):
return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in (
"glm_moe_dsa",
"deepseek_v3",
@@ -126,6 +129,62 @@ def softmax_topk_routing(
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def softmax_group_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
T, H = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
scores_for_choice = router_probs
# Group selection: pick top groups, mask the rest
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, E // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices)
# Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
# Flatten for moe_general_routing_inputs
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
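To see the group-selection step in isolation, here is a standalone sketch on toy numbers (E=8 experts in n_group=4 groups, keeping topk_group=2 groups, then top-K=2 experts) that mirrors the masking above without a real moe_block:
import torch

T, E, n_group, topk_group, K = 1, 8, 4, 2, 2
probs = torch.softmax(torch.randn(T, E), dim=-1)
# Score each group by the sum of its two best experts, keep the best groups
group_scores = probs.view(T, n_group, E // n_group).topk(2, dim=-1)[0].sum(-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
# Zero experts in dropped groups, then take the global top-K from the rest
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(T, E)
masked = probs.masked_fill(~score_mask.bool(), 0.0)
topk_idx = torch.topk(masked, k=K, dim=-1, sorted=False)[1]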
def sigmoid_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@@ -25,7 +25,7 @@ def get_lora_parameters(
) -> tuple[
torch.Tensor,
torch.Tensor | None,
QuantState | None,
QuantState | torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
@@ -48,9 +48,13 @@ def get_lora_parameters(
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
if quant_state is None and W.dtype == torch.float8_e4m3fn:
quant_state = getattr(base_layer, "weight_scale_inv", None)
return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
if quant_state is None and W.dtype == torch.float8_e4m3fn:
quant_state = getattr(base_layer, "weight_scale_inv", None)
active_adapter = (
proj.active_adapters[0]
@@ -81,7 +85,7 @@ def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor | None,
W_quant: QuantState | None,
W_quant: QuantState | torch.Tensor | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
s: float | None,
@@ -636,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
# This takes 65x fewer FLOPs than materializing B@A into [out, in]
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
# K path
k_weight_t = dequantize(k_weight, k_quant)
@@ -644,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
# V path
v_weight_t = dequantize(v_weight, v_quant)
@@ -652,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
# Transpose gradients if needed
if d_A_q is not None:
@@ -815,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
del W
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
dX.addmm_(torch.mm(dY, B), A, alpha=s)
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
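The equivalence behind both changes can be checked numerically: computing (dY @ B) @ A as two rank-R matmuls gives the same gradient as first materializing B @ A. A minimal sketch in float64 (toy sizes, names hypothetical):
import torch

T, out_f, in_f, R, s = 64, 512, 512, 16, 0.5
dY = torch.randn(T, out_f, dtype=torch.float64)
B = torch.randn(out_f, R, dtype=torch.float64)  # LoRA B
A = torch.randn(R, in_f, dtype=torch.float64)   # LoRA A
dX_materialized = s * (dY @ (B @ A))  # forms the full [out, in] product
dX_decomposed = torch.addmm(
    torch.zeros(T, in_f, dtype=torch.float64), dY @ B, A, alpha=s
)
assert torch.allclose(dX_materialized, dX_decomposed)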

View File

@@ -1,4 +1,4 @@
"""Dequantization utilities for `bitsandbytes` integration."""
"""Dequantization utilities for `bitsandbytes` and FP8 integration."""
import ctypes
@@ -15,9 +15,50 @@ CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
def dequantize_fp8(
W: torch.Tensor,
scale_inv: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
Args:
W: FP8 weight tensor [out_features, in_features] in float8_e4m3fn.
scale_inv: Per-block inverse scale [ceil(out/block), ceil(in/block)]
or per-tensor scalar.
dtype: Output dtype (default bf16).
Returns:
Dequantized tensor in the specified dtype.
"""
W_float = W.to(dtype)
if scale_inv.numel() == 1:
return W_float * scale_inv.to(dtype)
if scale_inv.dim() == 2 and W.dim() == 2:
sr, sc = scale_inv.shape
br = W.shape[0] // sr
bc = W.shape[1] // sc
# If dimensions are exactly divisible, use fast reshape path
if sr * br == W.shape[0] and sc * bc == W.shape[1]:
return (
W_float.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(dtype)
).reshape(W.shape)
# Tail-block handling: compute actual block size (ceil division),
# tile scale_inv to cover full shape, then crop to W's dimensions
br_ceil = -(-W.shape[0] // sr) # ceil(rows / scale_rows) = block_size
bc_ceil = -(-W.shape[1] // sc)
scale_expanded = (
scale_inv.to(dtype)
.repeat_interleave(br_ceil, dim=0)
.repeat_interleave(bc_ceil, dim=1)
)[: W.shape[0], : W.shape[1]]
return W_float * scale_expanded
return W_float * scale_inv.to(dtype)
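To illustrate the block layout the fast path handles, a toy 4x4 weight with 2x2 scale blocks (plain floats stand in for FP8 storage):
import torch

W = torch.arange(16.0, dtype=torch.bfloat16).reshape(4, 4)
scale_inv = torch.tensor([[2.0, 3.0], [4.0, 5.0]])  # one scale per 2x2 block
sr, sc = scale_inv.shape
br, bc = W.shape[0] // sr, W.shape[1] // sc
# Same reshape trick as above: [sr, br, sc, bc] * scale[:, None, :, None]
dq = (
    W.reshape(sr, br, sc, bc) * scale_inv[:, None, :, None].to(W.dtype)
).reshape(W.shape)
assert dq[0, 0] == W[0, 0] * 2.0 and dq[2, 3] == W[2, 3] * 5.0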
def dequantize(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
quant_state: QuantState | list | torch.Tensor | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
@@ -49,6 +90,15 @@ def dequantize(
if quant_state is None:
return W
# FP8 path: quant_state is actually scale_inv tensor
if W.dtype == torch.float8_e4m3fn:
scale_inv = quant_state
# Caller may pass W.t() (non-contiguous) — dequantize in original
# layout then transpose back so the result shape matches the input.
if not W.is_contiguous() and W.dim() == 2:
return dequantize_fp8(W.t(), scale_inv).t()
return dequantize_fp8(W, scale_inv)
# Get the target device from input tensor W
target_device = W.device

View File

@@ -160,6 +160,18 @@ def load_lora(
else:
model = get_peft_model(model, lora_config, **model_kwargs)
# FP8 models: LoRA A/B inherit FP8 dtype from base weights, but training
# requires a compute dtype (bf16/fp16). Cast trainable LoRA params.
if cfg.torch_dtype:
_fp8_cast_dtype = cfg.torch_dtype
elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
_fp8_cast_dtype = torch.bfloat16
else:
_fp8_cast_dtype = torch.float16
for _name, param in model.named_parameters():
if param.requires_grad and param.dtype == torch.float8_e4m3fn:
param.data = param.data.to(_fp8_cast_dtype)
if rank == 0:
try:
model.print_trainable_parameters()

View File

@@ -215,6 +215,8 @@ class ModelLoader:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
if "allow_all_kernels" not in self.model_kwargs:
self.model_kwargs["allow_all_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
self._check_model_requirements()
@@ -503,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
@@ -829,8 +845,9 @@ class ModelLoader:
def _set_z3_leaf_modules(self):
from deepspeed.utils import set_z3_leaf_modules
if self.cfg.model_config_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
if moe_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[moe_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
self.model,

View File

@@ -93,11 +93,13 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._deactivate_hf_async_load()
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -114,6 +116,8 @@ class PatchManager:
self._apply_patch_deepspeed_zero3()
self._apply_voxtral_patches()
self._apply_apertus_patches()
self._apply_trl_vllm_patches()
self._apply_trl_trainer_utils_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -129,13 +133,6 @@ class PatchManager:
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
@@ -227,6 +224,15 @@ class PatchManager:
patch_sageattn()
def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention:
return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
patch_flash_attn_4(self.model_config)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
@@ -409,17 +415,27 @@ class PatchManager:
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _deactivate_hf_async_load(self):
"""Load weights synchronously so they can be converted and not OOM."""
if self.cfg.load_in_4bit or self.cfg.load_in_8bit:
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
"""Patch transformers weight loading and PEFT for MoE expert quantization."""
has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
if not self.cfg.quantize_moe_experts and not has_target_params:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
@@ -646,6 +662,50 @@ class PatchManager:
patch_apertus_xielu_activation()
def _apply_trl_vllm_patches(self):
"""Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling."""
if (
self.cfg.rl
and getattr(self.cfg, "trl", None)
and getattr(self.cfg.trl, "use_vllm", False)
):
from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
patch_trl_vllm()
def _apply_trl_trainer_utils_patches(self):
"""Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels."""
if not self.cfg.rl:
return
try:
from axolotl.monkeypatch.trainer.utils import (
entropy_from_logits,
selective_log_softmax,
)
except (ImportError, ModuleNotFoundError):
LOG.warning("Triton not available — skipping trl.trainer.utils patches")
return
import trl.trainer.utils
# Guard against repeated calls: only stash the original if trl still
# points at its own implementation (not our wrapper).
if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils
_axolotl_trainer_utils.selective_log_softmax_original = (
trl.trainer.utils.selective_log_softmax
)
trl.trainer.utils.selective_log_softmax = selective_log_softmax
if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
trl.trainer.utils.entropy_from_logits = entropy_from_logits
LOG.info(
"Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits"
)
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:

View File

@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
processor.tokenizer = tokenizer
# Attempt to load image size from processor if available
if (

View File

@@ -78,30 +78,29 @@ def patch_parallelism_config():
def patch_prepare_cp():
import functools
import contextlib
import torch
from accelerate import Accelerator
from transformers import Trainer
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
self._cp_context = _noop_cp_context
return args
def _noop_prepare_context_parallel_inputs(self, model, inputs):
return contextlib.nullcontext, inputs
# prevent double CP partition
Accelerator._prepare_cp = patched_prepare_cp
# remove unneeded calculation upstream
Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs

View File

@@ -0,0 +1,104 @@
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""
import torch
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _get_head_dims(model_config):
"""Extract (head_dim, head_dim_v) from a model config.
Handles composite models (e.g. Qwen3.5 VL) via text_config and
MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
"""
cfg = model_config
if hasattr(cfg, "text_config"):
cfg = cfg.text_config
# MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
head_dim_v = getattr(cfg, "v_head_dim", head_dim)
return head_dim, head_dim_v
# Standard models
if hasattr(cfg, "head_dim"):
return cfg.head_dim, cfg.head_dim
if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
head_dim = cfg.hidden_size // cfg.num_attention_heads
return head_dim, head_dim
return None, None
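For instance, on an MLA-style config (DeepSeek-like attribute names, toy values) the helper returns distinct Q and V head dims, while a plain config falls back to hidden_size // num_attention_heads:
from types import SimpleNamespace

mla_cfg = SimpleNamespace(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)
assert _get_head_dims(mla_cfg) == (192, 128)  # Q = nope + rope, V = v_head_dim
std_cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert _get_head_dims(std_cfg) == (128, 128)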
def patch_flash_attn_4(model_config=None):
"""Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
if major not in (9, 10, 11):
return
try:
from flash_attn.cute import ( # noqa: F401
flash_attn_func,
flash_attn_varlen_func,
)
except ImportError:
LOG.info(
"Flash Attention 4 is available for your GPU and offers faster training speeds. "
"To enable: pip install flash-attn-4"
)
return
# Validate head dimensions against FA4's own constraints
head_dim = None
if model_config is not None:
head_dim, head_dim_v = _get_head_dims(model_config)
if head_dim is not None:
try:
from flash_attn.cute.interface import _validate_head_dims
except ImportError:
LOG.warning(
"Could not import _validate_head_dims from flash_attn.cute.interface, "
"unable to verify head dimension compatibility, falling back to FA2"
)
return
# alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
alignment = 8
try:
_validate_head_dims(head_dim, head_dim_v, major, alignment)
except AssertionError as exc:
LOG.warning(
"Model head dimensions not supported by FA4, "
"falling back to FA2: %s",
exc,
)
return
import transformers.modeling_flash_attention_utils as fa_utils
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
fa_utils._pad_input,
fa_utils._unpad_input,
)
_patched_lazy_imports._axolotl_patched = True
fa_utils._lazy_imports = _patched_lazy_imports
LOG.info(
"Flash Attention 4 enabled (head_dim=%s)",
head_dim if model_config else "unknown",
)

View File

@@ -64,15 +64,12 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
main_process_only=True,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info(
"Flex attention compiled successfully.", main_process_only=True
)
LOG.info("Flex attention compiled successfully.")
self._is_flex_compiled = True

View File

@@ -51,6 +51,29 @@ QKV_PATCHES = [
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
(
"""
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states, gate = torch.chunk(
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
]
ORIGINAL_O_CODE = """
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
if hasattr(pretrained_model.model, "language_model"):
return pretrained_model.model.language_model.layers
return pretrained_model.model.layers
raise NotImplementedError(

View File

@@ -1,11 +1,4 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
"""Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors."""
import bitsandbytes as bnb
import torch
@@ -15,18 +8,20 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
# Module path → param names in definition order, captured before quantization.
# Without this, alphabetical loading order would mismatch merge order.
"expert_param_order": {},
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
"""Dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
@@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module):
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
"""Flatten 3D+ to 2D for BnB's dequant, then reshape back."""
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
@@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name):
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly."""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
_moe_load_state["expert_param_order"] = {}
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
@@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg):
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
@@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg):
)
return
# Record definition order before parametrizations override it
# with alphabetical order.
if mod_path not in _moe_load_state["expert_param_order"]:
_moe_load_state["expert_param_order"][mod_path] = list(
mod._parameters.keys()
)
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
@@ -151,20 +149,28 @@ def get_moe_quantized_count():
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
"""Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts.
1. Expands short suffixes to full module paths for parametrized modules.
2. Iterates params in definition order (not alphabetical order) so saved
adapters are compatible with standard PEFT, vLLM, etc.
"""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
from contextlib import nullcontext
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.integrations import init_empty_weights
from peft.utils.other import _get_submodules
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
# Expand short suffixes to full paths for parametrized modules.
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
@@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching():
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
target_names_set = expanded
def strip_base_layer_from_name(module_name):
name = ".base_layer"
while name in module_name:
prefix, _, suffix = module_name.rpartition(name)
module_name = prefix + suffix
return module_name
def create_and_replace_param(module_name, key, param_name):
parent, target, target_name = _get_submodules(model, module_name)
unwrapped_module_name = strip_base_layer_from_name(module_name)
unwrapped_module = model.get_submodule(unwrapped_module_name)
if (
isinstance(unwrapped_module, BaseTunerLayer)
and unwrapped_module.__class__.__name__ != "ParamWrapper"
):
raise ValueError(
f"Trying to wrap an `nn.Parameter` of layer "
f"'{unwrapped_module_name}' of type "
f"{type(target).__name__}, which is not a valid target. "
f"Make sure that this layer is not also targeted with "
f"`target_modules`."
)
self._check_target_module_compatiblity(peft_config, model, target_name)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
self._create_and_replace(
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=key,
parameter_name=param_name.rpartition(".")[-1],
)
# Use definition order (not alphabetical order) for parametrized modules
# so ParamWrapper nesting matches vanilla PEFT on a plain model.
expert_param_order = _moe_load_state.get("expert_param_order", {})
for module_name, module in model.named_modules():
if hasattr(module, "parametrizations"):
stored_order = expert_param_order.get(module_name)
if stored_order is not None:
params_iter = [
p for p in stored_order if p in module.parametrizations
]
else:
# Fallback for paths that bypass model loading (e.g. unit tests).
params_iter = list(module.parametrizations.keys())
for param_name in params_iter:
key = f"{module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
else:
unwrapped_module_name = strip_base_layer_from_name(module_name)
for param_name, _ in module.named_parameters(recurse=False):
key = f"{unwrapped_module_name}.{param_name}"
if (key in target_names_set) or any(
key.endswith(f".{t}") for t in target_names_set
):
create_and_replace_param(module_name, key, param_name)
self.targeted_parameter_names.append(key)
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering")

View File

@@ -57,7 +57,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"olmo3",
"ministral",
"ministral3",
"mistral4",
"afmoe",
"nemotron",
]

View File

@@ -154,7 +154,6 @@ def register_ring_attn_from_device_mesh(
LOG.info(
f"Enabling ring attention sequence parallelism using DeviceMesh "
f"dimension '{context_parallel_dim}'",
main_process_only=True,
)
# Extract the sequence parallel submesh

View File

@@ -85,7 +85,6 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
mlp_cls._tiled_mlp_dist_impl = None
LOG.info(
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
main_process_only=True,
)
except (ImportError, AttributeError) as e:
raise RuntimeError(

View File

@@ -0,0 +1,3 @@
from .utils import entropy_from_logits, selective_log_softmax
__all__ = ["entropy_from_logits", "selective_log_softmax"]

View File

@@ -0,0 +1,245 @@
"""Monkeypatches for TRL's vLLM integration and trainer utils.
Adds:
- VLLMClient.batch_update_named_params: batched weight sync (fewer HTTP round-trips)
- extract_logprobs: NaN→0.0 fix (prevents downstream NaN propagation)
- VLLMGeneration: weight_sync_chunk_size + batched sync path for non-FSDP/non-ZeRO
- split_tensor_dict / shuffle_sequence_dict: scalar type handling (int/float/bool passthrough)
"""
import logging
import math
from functools import wraps
import torch
from torch import nn
LOG = logging.getLogger(__name__)
def _batch_update_named_params(
self, params: list[tuple[str, torch.Tensor]], chunk_size: int | None = None
):
"""Batched weight sync — sends param metadata via HTTP, tensors via NCCL."""
from transformers import is_torch_xpu_available
if chunk_size is None:
chunks = [params]
else:
chunks = []
current_chunk: list[tuple[str, torch.Tensor]] = []
current_elements = 0
for name, weights in params:
n_elem = weights.numel()
if current_chunk and current_elements + n_elem > chunk_size:
chunks.append(current_chunk)
current_chunk = []
current_elements = 0
current_chunk.append((name, weights))
current_elements += n_elem
if current_chunk:
chunks.append(current_chunk)
for chunk in chunks:
param_metadata = [
{"name": name, "dtype": str(weights.dtype), "shape": list(weights.shape)}
for name, weights in chunk
]
url = f"{self.base_url}/batch_update_named_params/"
response = self.session.post(url, json={"params": param_metadata})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
for _name, weights in chunk:
if is_torch_xpu_available():
self.communicator.broadcast(weights, root=self.rank)
else:
self.communicator.broadcast(weights, src=self.rank)
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
def _update_model_params(self, model: nn.Module, chunk_size: int | None = None):
"""Updates all model params using batch_update_named_params."""
params = [(name, param.data) for name, param in model.named_parameters()]
self.batch_update_named_params(params, chunk_size=chunk_size)
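The chunking splits by cumulative element count rather than parameter count, so one oversized tensor still gets its own chunk. A standalone trace of the same loop on toy metadata (names and sizes hypothetical):
params = [("a", 2_000_000), ("b", 2_500_000), ("c", 4_000_000)]  # (name, numel)
chunk_size = 5_000_000
chunks, current, elems = [], [], 0
for name, numel in params:
    if current and elems + numel > chunk_size:  # close chunk before overflow
        chunks.append(current)
        current, elems = [], 0
    current.append(name)
    elems += numel
if current:
    chunks.append(current)
assert chunks == [["a", "b"], ["c"]]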
def _patched_extract_logprobs(all_outputs):
"""extract_logprobs with NaN→0.0 fix (stock TRL uses None which causes downstream errors)."""
all_logprobs = []
all_token_ids = []
for outputs in all_outputs:
for output in outputs.outputs:
if output.logprobs is None:
return None, None
seq_logprobs = []
seq_token_ids = []
for lp in output.logprobs:
sorted_items = sorted(lp.items(), key=lambda x: x[1].rank)
seq_token_ids.append([token_id for token_id, _ in sorted_items])
seq_logprobs.append(
[
0.0 if math.isnan(item.logprob) else item.logprob
for _, item in sorted_items
]
)
all_logprobs.append(seq_logprobs)
all_token_ids.append(seq_token_ids)
return all_logprobs, all_token_ids
def _patched_split_tensor_dict(tensor_dict, num_chunks):
"""split_tensor_dict that handles scalar types (int/float/bool) for num_items_in_batch."""
first_tensor = next(
tensor
for tensor in tensor_dict.values()
if tensor is not None and isinstance(tensor, torch.Tensor) and tensor.ndim > 0
)
chunk_size = first_tensor.shape[0] // num_chunks
chunks = []
for i in range(num_chunks):
chunk_dict = {}
for key, tensor in tensor_dict.items():
if isinstance(tensor, (int, float, bool)):
chunk_dict[key] = tensor
elif tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0):
chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size]
elif tensor is not None and tensor.ndim == 0:
chunk_dict[key] = tensor
else:
chunk_dict[key] = None
chunks.append(chunk_dict)
return chunks
def _patched_shuffle_sequence_dict(seq_dict):
"""shuffle_sequence_dict that handles scalar types (int/float/bool)."""
first_seq = next(
v
for v in seq_dict.values()
if v is not None and isinstance(v, (torch.Tensor, list)) and len(v) > 0
)
perm = torch.randperm(len(first_seq))
def permute(v):
if v is None:
return None
if isinstance(v, (int, float, bool)):
return v
if isinstance(v, torch.Tensor) and v.ndim == 0:
return v
if isinstance(v, torch.Tensor) and v.ndim >= 1:
return v[perm]
if isinstance(v, list):
return [v[i] for i in perm.tolist()]
return v
return {k: permute(v) for k, v in seq_dict.items()}
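A quick illustration of the scalar passthrough (toy dict, hypothetical key names): stock TRL would try to index the int, whereas the patched version leaves it alone while still permuting tensors:
import torch

seq = {"input_ids": torch.arange(4).unsqueeze(-1), "num_items_in_batch": 7}
shuffled = _patched_shuffle_sequence_dict(seq)
assert shuffled["num_items_in_batch"] == 7  # scalar passed through untouched
assert sorted(shuffled["input_ids"].flatten().tolist()) == [0, 1, 2, 3]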
def _patch_sync_weights_batched(original_init):
"""Wrap VLLMGeneration.__init__ to accept weight_sync_chunk_size."""
@wraps(original_init)
def patched_init(self, *args, weight_sync_chunk_size=None, **kwargs):
original_init(self, *args, **kwargs)
self.weight_sync_chunk_size = weight_sync_chunk_size
return patched_init
def _make_batched_sync_weights(original_sync_weights):
"""Wrap sync_weights to use batched sync for non-FSDP/non-ZeRO paths."""
@wraps(original_sync_weights)
def patched_sync_weights(self):
from accelerate.utils import is_peft_model
# Check if we're in a non-PEFT, non-FSDP, non-ZeRO scenario where batching helps
accelerator = self.accelerator
model = self.model
is_fsdp_enabled = self.is_fsdp_enabled
deepspeed_plugin = accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
is_peft = is_peft_model(model)
# If PEFT, FSDP, or ZeRO-3, fall back to original (which handles those cases)
if is_peft or is_fsdp_enabled or zero_stage_3:
return original_sync_weights(self)
# Non-PEFT, non-FSDP, non-ZeRO: use batched sync
if self.mode == "colocate" and getattr(self, "enable_sleep_mode", False):
from vllm.distributed.device_communicators.cuda_wrapper import (
empty_cache,
)
empty_cache()
self.llm.wake_up(tags=["weights"])
if self.mode == "server" and accelerator.is_main_process:
params = [
(self._fix_param_name_to_vllm(name), param.data)
for name, param in model.named_parameters()
]
self.vllm_client.batch_update_named_params(
params, chunk_size=getattr(self, "weight_sync_chunk_size", None)
)
elif self.mode == "colocate":
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
weights = [
(self._fix_param_name_to_vllm(name), param.data)
for name, param in model.named_parameters()
]
llm_model.load_weights(weights=weights)
# Reset cache
if self.mode == "server" and accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
elif self.mode == "colocate":
self.llm.reset_prefix_cache()
return patched_sync_weights
def patch_trl_vllm():
"""Apply all TRL vLLM monkeypatches."""
import trl.generation.vllm_client
import trl.generation.vllm_generation
import trl.trainer.utils
VLLMClient = trl.generation.vllm_client.VLLMClient
VLLMGeneration = trl.generation.vllm_generation.VLLMGeneration
# 1. Add batch_update_named_params to VLLMClient
if not hasattr(VLLMClient, "batch_update_named_params"):
VLLMClient.batch_update_named_params = _batch_update_named_params
VLLMClient.update_model_params = _update_model_params
LOG.info("Patched VLLMClient with batch_update_named_params")
# 2. Patch extract_logprobs (NaN→0.0)
trl.generation.vllm_generation.extract_logprobs = _patched_extract_logprobs
LOG.info("Patched extract_logprobs with NaN→0.0 fix")
# 3. Patch VLLMGeneration.__init__ to accept weight_sync_chunk_size
VLLMGeneration.__init__ = _patch_sync_weights_batched(VLLMGeneration.__init__)
# 4. Patch sync_weights for batched non-FSDP/non-ZeRO path
VLLMGeneration.sync_weights = _make_batched_sync_weights(
VLLMGeneration.sync_weights
)
LOG.info("Patched VLLMGeneration with batched sync_weights")
# 5. Patch split_tensor_dict and shuffle_sequence_dict
trl.trainer.utils.split_tensor_dict = _patched_split_tensor_dict
trl.trainer.utils.shuffle_sequence_dict = _patched_shuffle_sequence_dict
LOG.info("Patched split_tensor_dict and shuffle_sequence_dict for scalar types")

View File

@@ -0,0 +1,429 @@
# Copyright 2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
@triton.jit
def _entropy_online_kernel(
logits_ptr,
output_ptr,
stride_row,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Online entropy: single pass with running max correction."""
row = tl.program_id(0)
row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
running_weighted = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
correction = tl.exp(running_max - new_max)
running_sum_exp = running_sum_exp * correction
running_weighted = running_weighted * correction
exp_x = tl.exp(x - new_max)
exp_x = tl.where(mask, exp_x, 0.0)
x = tl.where(mask, x, 0.0)
running_sum_exp += tl.sum(exp_x, axis=0)
running_weighted += tl.sum(exp_x * x, axis=0)
running_max = new_max
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
tl.store(output_ptr + row, entropy)
@triton.jit
def _entropy_online_kernel_strided(
logits_ptr,
output_ptr,
stride_outer,
stride_inner,
n_inner,
row_offset,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Online entropy for non-contiguous 3D (B, L, V) tensors."""
local_row = tl.program_id(0)
row = local_row + row_offset
outer_idx = row // n_inner
inner_idx = row % n_inner
off = outer_idx.to(tl.int64) * stride_outer + inner_idx.to(tl.int64) * stride_inner
row_ptr = logits_ptr + off
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
running_weighted = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
correction = tl.exp(running_max - new_max)
running_sum_exp = running_sum_exp * correction
running_weighted = running_weighted * correction
exp_x = tl.exp(x - new_max)
exp_x = tl.where(mask, exp_x, 0.0)
x = tl.where(mask, x, 0.0)
running_sum_exp += tl.sum(exp_x, axis=0)
running_weighted += tl.sum(exp_x * x, axis=0)
running_max = new_max
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
tl.store(output_ptr + local_row, entropy)
def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
"""Triton-fused entropy (online single-pass). Handles non-contiguous tensors without copying."""
original_shape = logits.shape[:-1]
V = logits.shape[-1]
N = 1
for s in original_shape:
N *= s
if not logits.is_cuda:
# CPU fallback: stable entropy via log_softmax
logp = F.log_softmax(logits.float(), dim=-1)
ent = -(logp.exp() * logp).sum(dim=-1)
return ent.to(logits.dtype).reshape(original_shape)
output = torch.empty(N, device=logits.device, dtype=torch.float32)
BLOCK_V = 4096
MAX_GRID_CONTIG = 8192
MAX_GRID_STRIDED = 2048
# Vocab (last) dim must be contiguous for coalesced loads
if logits.stride(-1) != 1:
logits = logits.contiguous()
if logits.is_contiguous():
flat_logits = logits.reshape(-1, V)
stride = flat_logits.stride(0)
for start in range(0, N, MAX_GRID_CONTIG):
n_rows = min(MAX_GRID_CONTIG, N - start)
_entropy_online_kernel[(n_rows,)](
flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
)
elif logits.ndim == 3:
stride_outer = logits.stride(0)
stride_inner = logits.stride(1)
n_inner = logits.shape[1]
for start in range(0, N, MAX_GRID_STRIDED):
n_rows = min(MAX_GRID_STRIDED, N - start)
_entropy_online_kernel_strided[(n_rows,)](
logits,
output[start],
stride_outer,
stride_inner,
n_inner,
start,
V=V,
BLOCK_V=BLOCK_V,
)
else:
logits = logits.contiguous()
flat_logits = logits.reshape(-1, V)
stride = flat_logits.stride(0)
for start in range(0, N, MAX_GRID_CONTIG):
n_rows = min(MAX_GRID_CONTIG, N - start)
_entropy_online_kernel[(n_rows,)](
flat_logits[start], output[start], stride, V=V, BLOCK_V=BLOCK_V
)
return output.to(logits.dtype).reshape(original_shape)
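The online update matches the standard two-pass entropy H = log Σe^(x−m) + m − Σ(e^(x−m)·x)/Σe^(x−m). A quick parity check against a PyTorch reference (assumes a CUDA device so the Triton path is taken):
import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 512, device="cuda", dtype=torch.bfloat16)
logp = F.log_softmax(logits.float(), dim=-1)
ref = (-(logp.exp() * logp).sum(-1)).to(logits.dtype)
assert torch.allclose(entropy_from_logits(logits), ref, atol=5e-2)  # loose bf16 tolerance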
# ---------------------------------------------------------------------------
# selective_log_softmax — fused forward + backward Triton kernels
# ---------------------------------------------------------------------------
def selective_log_softmax_original(logits, index) -> torch.Tensor:
"""Original selective_log_softmax (reference/fallback)."""
squeeze = index.ndim == logits.ndim - 1
if squeeze:
index = index.unsqueeze(-1)
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(logits, dim=-1, index=index)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values.unsqueeze(-1)
else:
per_token_logps = []
for row_logits, row_labels in zip(logits, index, strict=True):
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
if squeeze:
per_token_logps = per_token_logps.squeeze(-1)
return per_token_logps
@triton.jit
def _selective_logsoftmax_fwd_kernel(
logits_ptr,
index_ptr,
output_ptr,
logsumexp_ptr,
stride_logits_row,
stride_index_row,
stride_output_row,
actual_K,
K_BLOCK: tl.constexpr,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Forward: online logsumexp + gather. Saves logsumexp for backward."""
row = tl.program_id(0)
logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
# Online logsumexp
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(logits_row_ptr + offs, mask=mask, other=float("-inf")).to(
tl.float32
)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)
exp_x = tl.exp(x - new_max)
exp_x = tl.where(mask, exp_x, 0.0)
running_sum_exp += tl.sum(exp_x, axis=0)
running_max = new_max
lse = tl.log(running_sum_exp) + running_max
tl.store(logsumexp_ptr + row, lse)
# Gather and subtract
index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
output_row_ptr = output_ptr + tl.cast(row, tl.int64) * stride_output_row
k_offs = tl.arange(0, K_BLOCK)
k_mask = k_offs < actual_K
indices = tl.load(index_row_ptr + k_offs, mask=k_mask, other=0).to(tl.int64)
valid_mask = k_mask & (indices >= 0) & (indices < V)
safe_indices = tl.where(valid_mask, indices, 0)
selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask, other=0.0).to(
tl.float32
)
tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)
@triton.jit
def _selective_logsoftmax_bwd_kernel(
grad_output_ptr,
logits_ptr,
index_ptr,
logsumexp_ptr,
grad_logits_ptr,
stride_grad_out_row,
stride_logits_row,
stride_index_row,
stride_grad_logits_row,
actual_K,
K_BLOCK: tl.constexpr,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
):
"""Backward: d_logits[j] = -softmax(x)[j] * sum(grad_out) + (grad_out[k] if j == index[k]).
Single fused pass over V. For each tile, computes the base gradient and adds
scatter contributions inline by checking which indices fall in the current tile.
No separate scatter pass — no read-after-write issues.
"""
row = tl.program_id(0)
logits_row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_logits_row
grad_logits_row_ptr = (
grad_logits_ptr + tl.cast(row, tl.int64) * stride_grad_logits_row
)
grad_out_row_ptr = grad_output_ptr + tl.cast(row, tl.int64) * stride_grad_out_row
index_row_ptr = index_ptr + tl.cast(row, tl.int64) * stride_index_row
lse = tl.load(logsumexp_ptr + row).to(tl.float32)
# Load grad_output and indices (K_BLOCK elements, masked)
k_offs = tl.arange(0, K_BLOCK)
k_mask = k_offs < actual_K
grad_out = tl.load(grad_out_row_ptr + k_offs, mask=k_mask, other=0.0).to(tl.float32)
indices = tl.load(
index_row_ptr + k_offs, mask=k_mask, other=-1
) # -1 = never matches
valid_mask = k_mask & (indices >= 0) & (indices < V)
grad_out = tl.where(valid_mask, grad_out, 0.0)
indices = tl.where(valid_mask, indices, -1)
grad_sum = tl.sum(grad_out, axis=0)
# Fused pass: for each tile, compute -softmax * grad_sum + scatter
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V) # [BLOCK_V]
mask = offs < V
x = tl.load(logits_row_ptr + offs, mask=mask, other=0.0).to(tl.float32)
softmax_j = tl.exp(x - lse)
softmax_j = tl.where(mask, softmax_j, 0.0)
grad_j = -softmax_j * grad_sum
# Scatter: check which selected indices fall in this tile
# offs: [BLOCK_V], indices: [K_BLOCK]
# Broadcast: offs[:, None] == indices[None, :] → [BLOCK_V, K_BLOCK]
match = offs[:, None] == indices[None, :] # [BLOCK_V, K_BLOCK]
# Sum grad_out contributions: for each position j, sum grad_out[k] where index[k]==j
scatter_contrib = tl.sum(
tl.where(match, grad_out[None, :], 0.0), axis=1
) # [BLOCK_V]
grad_j += scatter_contrib
tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)
class _SelectiveLogSoftmaxTriton(torch.autograd.Function):
@staticmethod
def forward(ctx, flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID):
N = flat_logits.shape[0]
output = torch.empty(N, K_BLOCK, device=flat_logits.device, dtype=torch.float32)
logsumexp = torch.empty(N, device=flat_logits.device, dtype=torch.float32)
for start in range(0, N, MAX_GRID):
n_rows = min(MAX_GRID, N - start)
_selective_logsoftmax_fwd_kernel[(n_rows,)](
flat_logits[start],
flat_index[start],
output[start],
logsumexp[start],
flat_logits.stride(0),
flat_index.stride(0),
output.stride(0),
K,
K_BLOCK=K_BLOCK,
V=V,
BLOCK_V=BLOCK_V,
)
ctx.save_for_backward(flat_logits, flat_index, logsumexp)
ctx.K = K
ctx.K_BLOCK = K_BLOCK
ctx.V = V
ctx.BLOCK_V = BLOCK_V
ctx.MAX_GRID = MAX_GRID
return output
@staticmethod
def backward(ctx, grad_output):
flat_logits, flat_index, logsumexp = ctx.saved_tensors
K, K_BLOCK, V, BLOCK_V, MAX_GRID = (
ctx.K,
ctx.K_BLOCK,
ctx.V,
ctx.BLOCK_V,
ctx.MAX_GRID,
)
N = flat_logits.shape[0]
grad_logits = torch.empty_like(flat_logits)
# grad_output may have K_BLOCK cols; backward kernel reads actual_K
grad_output_contig = grad_output.contiguous()
for start in range(0, N, MAX_GRID):
n_rows = min(MAX_GRID, N - start)
_selective_logsoftmax_bwd_kernel[(n_rows,)](
grad_output_contig[start],
flat_logits[start],
flat_index[start],
logsumexp[start],
grad_logits[start],
grad_output_contig.stride(0),
flat_logits.stride(0),
flat_index.stride(0),
grad_logits.stride(0),
K,
K_BLOCK=K_BLOCK,
V=V,
BLOCK_V=BLOCK_V,
)
# Return grads for: flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
return grad_logits, None, None, None, None, None, None
def selective_log_softmax(logits, index) -> torch.Tensor:
"""
Fused selective_log_softmax with Triton forward+backward kernels.
Equivalent to: torch.gather(logits.log_softmax(-1), dim=-1, index=index)
"""
squeeze = index.ndim == logits.ndim - 1
if squeeze:
index = index.unsqueeze(-1)
if not logits.is_cuda or logits.dtype == torch.float64:
# Triton kernel computes in float32; fall back for float64 and CPU
return selective_log_softmax_original(
logits, index.squeeze(-1) if squeeze else index
)
V = logits.shape[-1]
K = index.shape[-1]
original_index_shape = index.shape
flat_logits = logits.reshape(-1, V).contiguous()
flat_index = index.reshape(-1, K).contiguous()
BLOCK_V = 4096
MAX_GRID = 8192
K_BLOCK = max(1, triton.next_power_of_2(K))
output = _SelectiveLogSoftmaxTriton.apply(
flat_logits, flat_index, K, K_BLOCK, V, BLOCK_V, MAX_GRID
)
if K_BLOCK != K:
output = output[:, :K]
per_token_logps = output.to(logits.dtype).reshape(original_index_shape)
if squeeze:
per_token_logps = per_token_logps.squeeze(-1)
return per_token_logps
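A parity check of both passes against the plain PyTorch expression (CUDA assumed; fp32 so tolerances stay tight):
import torch

logits = torch.randn(4, 16, 1000, device="cuda", requires_grad=True)
index = torch.randint(0, 1000, (4, 16), device="cuda")
out = selective_log_softmax(logits, index)
ref = torch.gather(logits.log_softmax(-1), -1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(out, ref, atol=1e-4)
out.sum().backward()  # exercises the fused backward kernel
g_triton, logits.grad = logits.grad.clone(), None
ref.sum().backward()
assert torch.allclose(g_triton, logits.grad, atol=1e-4)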

View File

@@ -1,72 +0,0 @@
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
from __future__ import annotations
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
def patch_prepare_context_parallel_inputs() -> None:
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
return
try:
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
except OSError as exc: # pragma: no cover - occurs when source is unavailable
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
return
if GUARD_PATTERN not in original_source:
LOG.warning(
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
"skipping FlashAttention context parallelism patch"
)
return
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
patched_source, _ = detab_code(patched_source)
patched_source = patched_source.replace(
"def _prepare_context_parallel_inputs(",
"def axolotl_prepare_context_parallel_inputs(",
1,
)
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# import symbols referenced in the method so exec can succeed
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
# Use a separate namespace to capture the exec'd function
namespace = {}
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
exec(patched_source, namespace)
# Explicitly get the function from the namespace
axolotl_prepare_context_parallel_inputs = namespace[
"axolotl_prepare_context_parallel_inputs"
]
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)

View File

@@ -0,0 +1,503 @@
"""vLLM serve script with native LoRA adapter support.
Extends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM,
instead of merging adapter weights into the base model before syncing.
Usage:
Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config,
or ``trl.vllm_lora_sync: true`` to auto-select.
Benefits over merge-sync:
- Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL
- vLLM handles LoRA application natively (Punica kernels)
- No NCCL communicator needed for weight sync
"""
import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Any
from trl.scripts.vllm_serve import (
ScriptArguments,
chunk_list,
extract_logprobs,
get_open_port,
)
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
logger = logging.getLogger(__name__)
@dataclass
class LoRAScriptArguments(ScriptArguments):
"""Extended script arguments with LoRA support."""
enable_lora: bool = field(
default=True,
metadata={"help": "Enable LoRA adapter support in vLLM."},
)
max_lora_rank: int = field(
default=64,
metadata={"help": "Maximum LoRA rank supported."},
)
max_loras: int = field(
default=2,
metadata={"help": "Maximum number of LoRA adapters loaded simultaneously."},
)
lora_dtype: str = field(
default="bfloat16",
metadata={"help": "Data type for LoRA weights."},
)
def llm_worker(
script_args: LoRAScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
"""Worker process that creates a vLLM LLM with LoRA enabled."""
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
# Use batch-capable worker extension (adds batch_update_named_params + auto-close)
worker_extension_cls="axolotl.scripts.vllm_worker_ext.BatchWeightSyncWorkerExtension",
trust_remote_code=script_args.trust_remote_code,
model_impl=script_args.vllm_model_impl,
logprobs_mode="processed_logprobs",
# LoRA
enable_lora=script_args.enable_lora,
max_lora_rank=script_args.max_lora_rank,
max_loras=script_args.max_loras,
lora_dtype=script_args.lora_dtype,
)
connection.send({"status": "ready"})
while True:
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args = command.get("args", ())
kwargs = command.get("kwargs", {})
# Reconstruct LoRARequest from serialized dict (can't pickle across pipe)
if "lora_request" in kwargs and kwargs["lora_request"] is not None:
lr = kwargs["lora_request"]
kwargs["lora_request"] = LoRARequest(
lora_name=lr["lora_name"],
lora_int_id=lr["lora_int_id"],
lora_path=lr["lora_path"],
load_inplace=lr.get("load_inplace", False),
)
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
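From the parent side, a call routed through this loop serializes the LoRARequest as a plain dict, since the object itself isn't sent across the pipe; a hypothetical sketch (sampling_params is a vllm.SamplingParams built elsewhere):
connections[0].send({
    "type": "call",
    "method": "generate",
    "args": (["Hello"],),
    "kwargs": {
        "sampling_params": sampling_params,  # assumed constructed by the caller
        "lora_request": {  # plain dict; the worker rebuilds LoRARequest from it
            "lora_name": "adapter",
            "lora_int_id": 1,
            "lora_path": "/tmp/adapter",
            "load_inplace": True,
        },
    },
})
outputs = connections[0].recv()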
def main(script_args: ScriptArguments):
"""Start vLLM workers with LoRA support and the HTTP server."""
import asyncio
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field as PydanticField
# Request/Response models (defined locally like TRL's vllm_serve.main)
class GenerateRequest(BaseModel):
prompts: list[str]
images: list[str] | None = None
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
logprobs: int | None = 0
truncate_prompt_tokens: int | None = None
structured_outputs_regex: str | None = None
generation_kwargs: dict = PydanticField(default_factory=dict)
class GenerateResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
class ChatRequest(BaseModel):
messages: list[list[dict]]
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
logprobs: int | None = 0
truncate_prompt_tokens: int | None = None
structured_outputs_regex: str | None = None
generation_kwargs: dict = PydanticField(default_factory=dict)
chat_template_kwargs: dict = PydanticField(default_factory=dict)
class ChatResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
class InitCommunicatorRequest(BaseModel):
host: str
port: int
world_size: int
client_device_uuid: str
# Wrap plain ScriptArguments with LoRA defaults
if not isinstance(script_args, LoRAScriptArguments):
lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments)
for f in ScriptArguments.__dataclass_fields__:
setattr(lora_args, f, getattr(script_args, f))
# Apply LoRA defaults
for f in LoRAScriptArguments.__dataclass_fields__:
if f not in ScriptArguments.__dataclass_fields__:
setattr(
lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default
)
script_args = lora_args
# Spawn workers
master_port = get_open_port()
connections: list[Connection] = []
processes: list[Process] = []
for dp_rank in range(script_args.data_parallel_size):
parent_conn, child_conn = Pipe()
process = Process(
target=llm_worker,
args=(script_args, dp_rank, master_port, child_conn),
)
process.start()
connections.append(parent_conn)
processes.append(process)
@asynccontextmanager
async def lifespan(app: FastAPI):
import time
startup_timeout = 300 # 5 minutes
start_time = time.monotonic()
ready: set[int] = set()
while len(ready) < script_args.data_parallel_size:
elapsed = time.monotonic() - start_time
if elapsed > startup_timeout:
raise RuntimeError(
f"vLLM workers failed to start within {startup_timeout}s "
f"({len(ready)}/{script_args.data_parallel_size} ready)"
)
for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)):
if id(conn) in ready:
continue
if not proc.is_alive():
raise RuntimeError(
f"vLLM worker {i} exited unexpectedly during startup"
)
if conn.poll():
msg = conn.recv()
if isinstance(msg, dict) and msg.get("status") == "ready":
ready.add(id(conn))
await asyncio.sleep(0.1)
yield
for p in processes:
p.join(timeout=10)
if p.is_alive():
p.terminate()
p.join()
app = FastAPI(lifespan=lifespan)
# --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None}
# ------------------------------------------------------------------
# LoRA-specific endpoints
# ------------------------------------------------------------------
class SetLoRARequest(BaseModel):
lora_name: str
lora_int_id: int
lora_path: str
load_inplace: bool = False
@app.post("/set_lora_adapter/")
async def set_lora_adapter(request: SetLoRARequest):
"""Register a LoRA adapter for all subsequent generate/chat calls."""
active_lora["request"] = {
"lora_name": request.lora_name,
"lora_int_id": request.lora_int_id,
"lora_path": request.lora_path,
"load_inplace": request.load_inplace,
}
logger.info(
"Set active LoRA: %s (id=%d, path=%s)",
request.lora_name,
request.lora_int_id,
request.lora_path,
)
return {"status": "ok"}
@app.post("/clear_lora_adapter/")
async def clear_lora_adapter():
"""Clear active LoRA adapter (revert to base model)."""
active_lora["request"] = None
return {"status": "ok"}
# ------------------------------------------------------------------
# Standard endpoints (mirrors TRL's vllm_serve)
# ------------------------------------------------------------------
@app.get("/health/")
async def health():
return {"status": "ok"}
@app.get("/get_world_size/")
async def get_world_size():
return {
"world_size": script_args.tensor_parallel_size
* script_args.data_parallel_size
}
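# Example /generate/ payload (illustrative values; the response carries prompt
# and completion token ids plus per-token logprobs, per GenerateResponse):
#
#   {"prompts": ["Hello"], "n": 2, "temperature": 0.7, "max_tokens": 32,
#    "logprobs": 1}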
@app.post("/generate/", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
"""Generate completions with optional LoRA adapter."""
import base64
from io import BytesIO
import vllm
from packaging.version import Version
from vllm.sampling_params import GuidedDecodingParams
images: list[str | None] = request.images or [None] * len(request.prompts) # type: ignore[assignment,list-item]
prompts: list[dict[str, Any]] = []
for prompt, image in zip(request.prompts, images, strict=True):
row: dict[str, Any] = {"prompt": prompt}
if image is not None:
from PIL import Image
row["multi_modal_data"] = {
"image": Image.open(BytesIO(base64.b64decode(image)))
}
prompts.append(row)
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"logprobs": request.logprobs,
}
generation_kwargs.update(request.generation_kwargs)
if Version(vllm.__version__) <= Version("0.10.2"):
key = "guided_decoding"
if request.structured_outputs_regex is not None:
generation_kwargs[key] = GuidedDecodingParams(
regex=request.structured_outputs_regex
)
else:
generation_kwargs.setdefault(key, None)
else:
from vllm.sampling_params import StructuredOutputsParams
key = "structured_outputs"
if request.structured_outputs_regex is not None:
generation_kwargs[key] = StructuredOutputsParams(
regex=request.structured_outputs_regex
)
elif isinstance(generation_kwargs.get(key), dict):
generation_kwargs[key] = StructuredOutputsParams(
**generation_kwargs[key]
)
else:
generation_kwargs.setdefault(key, None)
sampling_params = SamplingParams(**generation_kwargs)
chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)
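# Every DP worker must receive a generate call, so idle ranks are fed a
# throwaway placeholder prompt; their outputs are dropped by the chunk
# filter below.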
for conn, chunk in zip(connections, chunked_prompts, strict=True):
if not chunk:
chunk = [{"prompt": "<placeholder>"}]
kwargs = {
"prompts": chunk,
"sampling_params": sampling_params,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
all_outputs = [
o for o, c in zip(all_outputs, chunked_prompts, strict=True) if c
]
all_outputs = list(chain.from_iterable(all_outputs))
return {
"prompt_ids": [o.prompt_token_ids for o in all_outputs],
"completion_ids": [
list(out.token_ids) for o in all_outputs for out in o.outputs
],
"logprobs": extract_logprobs(all_outputs)[0],
"logprob_token_ids": extract_logprobs(all_outputs)[1],
}
@app.post("/chat/", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""Chat endpoint with optional LoRA adapter."""
generation_kwargs = {
"n": request.n,
"repetition_penalty": request.repetition_penalty,
"temperature": request.temperature,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"logprobs": request.logprobs,
}
generation_kwargs.update(request.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
chunked = chunk_list(request.messages, script_args.data_parallel_size)
for conn, chunk in zip(connections, chunked, strict=True):
if not chunk:
chunk = [[{"role": "user", "content": "<placeholder>"}]]
kwargs = {
"messages": chunk,
"sampling_params": sampling_params,
"use_tqdm": False,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "chat", "kwargs": kwargs})
all_outputs = [conn.recv() for conn in connections]
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
all_outputs = list(chain.from_iterable(all_outputs))
return {
"prompt_ids": [o.prompt_token_ids for o in all_outputs],
"completion_ids": [
list(out.token_ids) for o in all_outputs for out in o.outputs
],
"logprobs": extract_logprobs(all_outputs)[0],
"logprob_token_ids": extract_logprobs(all_outputs)[1],
}
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
@app.post("/init_communicator/")
async def init_communicator(request: InitCommunicatorRequest):
world_size = (
script_args.tensor_parallel_size * script_args.data_parallel_size + 1
)
kwargs = {
"method": "init_communicator",
"args": (
request.host,
request.port,
world_size,
request.client_device_uuid,
),
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": "Initializing communicator"}
class UpdateWeightsRequest(BaseModel):
name: str
dtype: str
shape: list[int]
@app.post("/update_named_param/")
async def update_named_param(request: UpdateWeightsRequest):
kwargs = {
"method": "update_named_param",
"args": (request.name, request.dtype, tuple(request.shape)),
}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": "Updating parameter"}
class BatchUpdateWeightsRequest(BaseModel):
params: list[dict]
@app.post("/batch_update_named_params/")
async def batch_update_named_params(request: BatchUpdateWeightsRequest):
params_list = [
(p["name"], p["dtype"], tuple(p["shape"])) for p in request.params
]
kwargs = {"method": "batch_update_named_params", "args": (params_list,)}
msg = {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
loop = asyncio.get_running_loop()
await asyncio.gather(
*(loop.run_in_executor(None, c.send, msg) for c in connections)
)
return {"message": f"Batch update for {len(params_list)} params"}
@app.post("/reset_prefix_cache/")
async def reset_prefix_cache():
for conn in connections:
conn.send({"type": "call", "method": "reset_prefix_cache"})
results = [conn.recv() for conn in connections]
return {"message": f"Reset prefix cache: {all(results)}"}
@app.post("/close_communicator/")
async def close_communicator():
kwargs = {"method": "close_communicator"}
for conn in connections:
conn.send(
{
"type": "fire_and_forget",
"method": "collective_rpc",
"kwargs": kwargs,
}
)
return {"message": "Closing communicator"}
uvicorn.run(
app,
host=script_args.host,
port=script_args.port,
log_level=script_args.log_level,
access_log=True,
)

View File

@@ -0,0 +1,158 @@
"""Extended vLLM worker extension with batch weight sync support.
Subclasses TRL's WeightSyncWorkerExtension to add:
- batch_update_named_params: receives multiple params in one call
- Auto-close stale communicator on re-init
- _direct_set_weight: proper handling for stacked (qkv_proj, gate_up_proj) params,
including LoRA-wrapped models where vLLM inserts base_layer into the hierarchy
"""
import logging
import torch
try:
from transformers import is_torch_xpu_available
except ImportError:
is_torch_xpu_available = lambda: False # noqa: E731
from trl.scripts.vllm_serve import WeightSyncWorkerExtension
logger = logging.getLogger(__name__)
# Stacked param name mapping: shard_name -> (packed_name, shard_order)
_STACKED_PARAMS = {
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
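# Example: "model.layers.0.self_attn.q_proj.weight" lands in shard 0 of
# "model.layers.0.self_attn.qkv_proj.weight" (or the ".base_layer." variant
# when LoRA is enabled), at the row offset derived from output_sizes.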
class BatchWeightSyncWorkerExtension(WeightSyncWorkerExtension):
"""Worker extension that adds batch weight update and direct weight setting."""
def init_communicator(self, host, port, world_size, client_device_uuid):
"""Auto-close stale communicator before re-initializing."""
if self.communicator is not None:
self.close_communicator()
super().init_communicator(host, port, world_size, client_device_uuid)
def _direct_set_weight(self, name: str, weight: torch.Tensor) -> None:
"""Directly copy weight data into the model, handling stacked params.
Bypasses model.load_weights() which may fail on vLLM 0.17's new
module-tree weight loader for stacked params (qkv_proj, gate_up_proj).
Handles LoRA-wrapped params where vLLM inserts ``base_layer`` into the
parameter hierarchy (e.g. ``qkv_proj.base_layer.weight``).
"""
model = self.model_runner.model
params_dict = dict(model.named_parameters())
# Check if this is a simple direct param (exists as-is)
if name in params_dict:
params_dict[name].data.copy_(weight.to(params_dict[name].dtype))
return
# Also check with base_layer inserted: x.y.weight -> x.y.base_layer.weight
parts_bl = name.rsplit(".", 1)
if len(parts_bl) == 2:
base_layer_name = f"{parts_bl[0]}.base_layer.{parts_bl[1]}"
if base_layer_name in params_dict:
params_dict[base_layer_name].data.copy_(
weight.to(params_dict[base_layer_name].dtype)
)
return
# Handle stacked params: e.g. "model.layers.0.self_attn.q_proj.weight"
# -> "model.layers.0.self_attn.qkv_proj.weight" with shard offset
parts = name.rsplit(".", 2) # [prefix, layer_name, suffix]
if len(parts) == 3:
prefix, layer_name, suffix = parts
if layer_name in _STACKED_PARAMS:
packed_name, shard_idx = _STACKED_PARAMS[layer_name]
for packed_full in [
f"{prefix}.{packed_name}.{suffix}",
f"{prefix}.{packed_name}.base_layer.{suffix}",
]:
if packed_full not in params_dict:
continue
param = params_dict[packed_full]
# Navigate to the packed module to find shard sizes
module_path = packed_full.rsplit(".", 1)[0] # strip .weight/.bias
if ".base_layer" in module_path:
module_path = module_path.replace(".base_layer", "")
module = model
for attr in module_path.split("."):
module = getattr(module, attr, None)
if module is None:
break
# LoRA wrappers don't have output_sizes directly;
# check base_layer for the underlying parallel linear
if module is not None and not hasattr(module, "output_sizes"):
base = getattr(module, "base_layer", None)
if base is not None and hasattr(base, "output_sizes"):
module = base
if module is not None and hasattr(module, "output_sizes"):
tp_size = getattr(module, "tp_size", 1)
sizes = [s // tp_size for s in module.output_sizes]
offset = sum(sizes[:shard_idx])
shard_size = sizes[shard_idx]
param.data[offset : offset + shard_size].copy_(
weight.to(param.dtype)
)
return
# Fallback: try load_weights (may work for non-stacked params)
logger.warning("Falling back to load_weights for param: %s", name)
model.load_weights(weights=[(name, weight)])
def update_named_param(self, name, dtype, shape):
"""Override to use _direct_set_weight instead of load_weights."""
if self.communicator is None:
raise RuntimeError("Communicator not initialized.")
dtype = getattr(torch, dtype.split(".")[-1])
weight = torch.empty(shape, dtype=dtype, device=self.device)
if is_torch_xpu_available():
self.communicator.broadcast(weight, root=self.client_rank)
self.communicator.barrier()
else:
self.communicator.broadcast(weight, src=self.client_rank)
self.communicator.group.barrier()
self._direct_set_weight(name, weight)
def batch_update_named_params(self, params_list: list[tuple[str, str, tuple]]):
"""Receive and apply multiple weight tensors in sequence.
Args:
params_list: List of (name, dtype_str, shape) tuples.
"""
if self.communicator is None:
raise RuntimeError("Communicator not initialized.")
weights_to_load = []
for name, dtype_str, shape in params_list:
dtype = getattr(torch, dtype_str.split(".")[-1])
weight = torch.empty(shape, dtype=dtype, device=self.device)
if is_torch_xpu_available():
self.communicator.broadcast(weight, root=self.client_rank)
else:
self.communicator.broadcast(weight, src=self.client_rank)
weights_to_load.append((name, weight))
# Single barrier after all broadcasts
if is_torch_xpu_available():
self.communicator.barrier()
else:
self.communicator.group.barrier()
# Load weights using direct set (handles stacked params)
for name, weight in weights_to_load:
self._direct_set_weight(name, weight)
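# The trainer side must broadcast tensors in exactly the same order as
# params_list, then hit one barrier, mirroring the receive path above;
# a hedged sketch (hypothetical client-side objects):
#
#   for name, param in named_params:
#       communicator.broadcast(param.data, src=client_rank)
#   communicator.group.barrier()  # matches the single barrier above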

View File

@@ -69,7 +69,6 @@ def setup_model_and_tokenizer(
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)

View File

@@ -59,7 +59,6 @@ class DynamicCheckpointCallback(TrainerCallback):
f"Dynamic checkpoint enabled. To trigger checkpoint save:\n"
f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
f" • Check interval: every {self.check_interval} steps",
main_process_only=True,
)
def on_step_end(
@@ -89,12 +88,10 @@ class DynamicCheckpointCallback(TrainerCallback):
LOG.info(
f"Dynamic checkpoint triggered via file '{self.trigger_filename}' "
f"at step {state.global_step}",
main_process_only=True,
)
except OSError as exc:
LOG.warning(
f"Failed to delete trigger file: {exc}",
main_process_only=True,
)
if self.should_save_checkpoint:
@@ -127,6 +124,5 @@ class DynamicCheckpointCallback(TrainerCallback):
control.should_save = True
LOG.info(
f"Saving dynamic checkpoint at step {state.global_step}",
main_process_only=True,
)
return control

View File

@@ -17,6 +17,8 @@ from transformers import (
class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
Also runs torch.profiler to produce a Chrome trace for timing analysis.
"""
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated; if we only
# started at the beginning of step 0, the traces would miss those earlier allocations
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
self._profiler = None
def on_step_begin(
self,
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs,
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
)
profiler.__enter__()
self._profiler = profiler
def on_step_end(
self,
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
# Stop and export torch.profiler trace
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None
def on_train_end(
self,
args: TrainingArguments,
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None

View File

@@ -84,7 +84,7 @@ def resolve_dtype(cfg):
cfg.fp16 = True
cfg.bf16 = False
else:
if cfg.tf32:
if cfg.tf32 is True:
torch.set_float32_matmul_precision("high")
if is_torch_greater_or_equal("2.9.0"):
torch.backends.fp32_precision = "tf32"
@@ -195,6 +195,15 @@ def normalize_config(cfg):
cfg.model_config_type = model_config.model_type
# Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4)
if callable(getattr(model_config, "get_text_config", None)):
text_config = model_config.get_text_config()
if (
hasattr(text_config, "model_type")
and text_config.model_type != model_config.model_type
):
cfg.model_config_type_text = text_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(

View File

@@ -348,7 +348,9 @@ def _load_raw_datasets(
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing:
if (split == "train" and cfg.sample_packing) or (
split == "test" and cfg.eval_sample_packing
):
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Deduplicate before saving so the saved dataset is already de-duplicated

View File

@@ -474,13 +474,11 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
):
LOG.info(
f"Loading prepared dataset from disk at {prepared_ds_path}...",
main_process_only=True,
)
return load_from_disk(str(prepared_ds_path))
LOG.info(
f"Unable to find prepared dataset in {prepared_ds_path}",
main_process_only=True,
)
return None

View File

@@ -29,7 +29,7 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
from torchao.prototype.mx_formats import NVFP4InferenceConfig
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except:
except (ImportError, RuntimeError):
pass
# int4 weight config imports will fail on machines with fbgemm-gpu installed
@@ -38,7 +38,7 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
from torchao.quantization.quant_api import Int4WeightOnlyConfig
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
except:
except (ImportError, RuntimeError):
pass
try:

View File

@@ -407,9 +407,11 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "No AMP (automatic mixed precision)"},
) # for non-AMP cases
tf32: bool | None = Field(
default=None,
json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"},
tf32: Literal["auto"] | bool | None = Field(
default="auto",
json_schema_extra={
"description": "bool to use CUDA tf32 or 'auto' for automatic detection - require >=ampere"
},
)
float32: bool | None = None
@@ -1218,6 +1220,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="after")
def check_tf32(self):
if self.tf32 == "auto":
self.tf32 = self.capabilities.tf32
return self
@model_validator(mode="after")
def check_fp8(self):
if self.fp8 and not self.capabilities.fp8:

View File

@@ -10,6 +10,7 @@ class GPUCapabilities(BaseModel):
bf16: bool = Field(default=False)
fp8: bool = Field(default=False)
tf32: bool = Field(default=False)
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)

View File

@@ -189,3 +189,125 @@ class TRLConfig(BaseModel):
"'normalize_then_sum' (GDPO): normalizes each reward independently, then sums."
},
)
# Async GRPO fields
use_data_producer: bool = Field(
default=False,
json_schema_extra={
"description": "Use the GRPODataProducer protocol for online data generation."
},
)
async_prefetch: bool = Field(
default=False,
json_schema_extra={
"description": "Generate rollouts in a background thread while training on the previous rollout."
},
)
prefetch_depth: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of rollouts to prefetch ahead of training."
},
)
vllm_sync_interval: int | None = Field(
default=None,
json_schema_extra={
"description": "Sync model weights to vLLM every N optimizer steps (async mode only)."
},
)
streaming_partial_batch: bool | None = Field(
default=None,
json_schema_extra={
"description": "Score prompt groups incrementally instead of the full batch at once."
},
)
streaming_min_groups: int | None = Field(
default=None,
json_schema_extra={
"description": "Minimum prompt groups to score per streaming chunk."
},
)
vllm_importance_sampling_correction: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply IS correction for distribution mismatch between vLLM and training model."
},
)
vllm_importance_sampling_mode: (
Literal["token_truncate", "token_mask", "sequence_truncate", "sequence_mask"]
| None
) = Field(
default=None,
json_schema_extra={
"description": "IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask."
},
)
vllm_importance_sampling_cap: float | None = Field(
default=None,
json_schema_extra={"description": "Cap C for IS ratio clipping/masking."},
)
off_policy_mask_threshold: float | None = Field(
default=None,
json_schema_extra={
"description": "KL threshold for off-policy sequence masking (OPSM). None = disabled."
},
)
use_bias_correction_kl: bool | None = Field(
default=None,
json_schema_extra={"description": "Apply IS correction to KL divergence term."},
)
reward_num_workers: int = Field(
default=1,
json_schema_extra={
"description": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
},
)
replay_buffer_size: int = Field(
default=0,
json_schema_extra={
"description": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
},
)
replay_recompute_logps: bool = Field(
default=True,
json_schema_extra={
"description": "When True (default), recompute old_per_token_logps for replayed groups using the current "
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
"Only relevant when replay_buffer_size > 0."
},
)
reroll_start_fraction: float = Field(
default=1.0,
json_schema_extra={
"description": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
},
)
reroll_max_groups: int = Field(
default=1,
json_schema_extra={
"description": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
},
)
skip_zero_advantage_batches: bool = Field(
default=True,
json_schema_extra={
"description": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
"logged with skipped_zero_adv_batches=1 for monitoring."
},
)
vllm_lora_sync: bool = Field(
default=False,
json_schema_extra={
"description": "Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. "
"Auto-selects vllm_serve_lora serve module. Syncs only LoRA adapter weights vs full merged model."
},
)
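# These fields map to the `trl:` section of an axolotl YAML config; a minimal
# async GRPO sketch (illustrative values, not a tuned recipe):
#
#   trl:
#     use_data_producer: true
#     async_prefetch: true
#     vllm_sync_interval: 4
#     vllm_lora_sync: true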

View File

@@ -128,7 +128,6 @@ class DatasetValidationMixin:
):
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`",
main_process_only=True,
)
data["eval_sample_packing"] = True
@@ -254,6 +253,23 @@ class TrainingValidationMixin:
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@classmethod
def set_reward_model_defaults(cls, data):
if data.get("reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 1
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForSequenceClassification"
if data.get("process_reward_model"):
if data.get("num_labels") is None:
data["num_labels"] = 2
if not (data.get("type_of_model") or data.get("model_type")):
data["model_type"] = "AutoModelForTokenClassification"
return data
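# Example: setting `reward_model: true` alone now defaults to num_labels=1 and
# model_type="AutoModelForSequenceClassification"; `process_reward_model: true`
# defaults to num_labels=2 with AutoModelForTokenClassification.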
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
@@ -676,20 +692,6 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_rl(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("rl"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with RL at the moment."
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):

View File

@@ -57,3 +57,10 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Reasoning parser for VLLM"},
)
serve_module: str | None = Field(
default=None,
json_schema_extra={
"description": "Python module for vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' "
"for native LoRA support, or leave None for default TRL serve."
},
)

View File

@@ -0,0 +1,220 @@
"""Unit tests for async GRPO"""
import unittest
from unittest.mock import MagicMock
import torch
class TestReplayBuffer(unittest.TestCase):
"""Tests for ReplayBuffer edge cases."""
def test_add_noop_when_max_size_zero(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=0)
buf.add(1.0, {"data": "test"})
self.assertEqual(len(buf), 0)
def test_add_noop_when_max_size_negative(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=-1)
buf.add(1.0, {"data": "test"})
self.assertEqual(len(buf), 0)
def test_sample_returns_none_when_max_size_zero(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=0)
self.assertIsNone(buf.sample(1))
def test_sample_returns_none_when_empty(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=5)
self.assertIsNone(buf.sample(1))
def test_normal_add_and_sample(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=3)
buf.add(1.0, {"a": 1})
buf.add(2.0, {"a": 2})
buf.add(3.0, {"a": 3})
self.assertEqual(len(buf), 3)
result = buf.sample(1)
self.assertIsNotNone(result)
self.assertEqual(len(result), 1)
def test_replaces_lowest_when_full(self):
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
buf = ReplayBuffer(max_size=2)
buf.add(1.0, {"a": 1})
buf.add(2.0, {"a": 2})
buf.add(3.0, {"a": 3}) # should replace score=1.0
self.assertEqual(len(buf), 2)
scores = sorted(item[0] for item in buf._heap)
self.assertEqual(scores, [2.0, 3.0])
class TestGRPOStrategyConflict(unittest.TestCase):
"""Tests for sequence_parallel + async_grpo conflict detection."""
def test_raises_on_both_enabled(self):
from axolotl.core.trainers.grpo import GRPOStrategy
with self.assertRaises(ValueError) as ctx:
GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=True)
self.assertIn("sequence_parallel", str(ctx.exception))
self.assertIn("async_grpo", str(ctx.exception))
def test_sequence_parallel_only(self):
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOSequenceParallelTrainer,
)
cls = GRPOStrategy.get_trainer_class(sequence_parallel=True, async_grpo=False)
self.assertIs(cls, AxolotlGRPOSequenceParallelTrainer)
def test_async_only(self):
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.grpo.trainer import AxolotlAsyncGRPOTrainer
cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=True)
self.assertIs(cls, AxolotlAsyncGRPOTrainer)
def test_neither(self):
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
cls = GRPOStrategy.get_trainer_class(sequence_parallel=False, async_grpo=False)
self.assertIs(cls, AxolotlGRPOTrainer)
class TestDequantizeFP8TailBlocks(unittest.TestCase):
"""Tests for FP8 dequantization with non-divisible dimensions."""
def test_exact_divisible_shape(self):
from axolotl.kernels.quantize import dequantize_fp8
W = torch.randn(256, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.ones(2, 1, dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (256, 128))
self.assertEqual(result.dtype, torch.bfloat16)
def test_non_divisible_rows(self):
from axolotl.kernels.quantize import dequantize_fp8
# 131 rows with 2 scale blocks does not divide evenly, so the final block
# covers fewer rows than the others and exercises the tail-block handling.
W = torch.ones(131, 128, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.tensor([[2.0], [3.0]], dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (131, 128))
self.assertEqual(result.dtype, torch.bfloat16)
def test_non_divisible_cols(self):
from axolotl.kernels.quantize import dequantize_fp8
W = torch.ones(128, 200, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.ones(1, 2, dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (128, 200))
def test_scalar_scale(self):
from axolotl.kernels.quantize import dequantize_fp8
W = torch.ones(64, 64, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
scale_inv = torch.tensor(2.0, dtype=torch.bfloat16)
result = dequantize_fp8(W, scale_inv)
self.assertEqual(result.shape, (64, 64))
class TestLoraFP8Guard(unittest.TestCase):
"""Tests that get_lora_parameters only uses weight_scale_inv for FP8 weights."""
def test_non_fp8_weight_skips_scale_inv(self):
"""Non-FP8 weight should NOT pick up weight_scale_inv as quant_state."""
from axolotl.kernels.lora import get_lora_parameters
proj = MagicMock()
proj.disable_adapters = True
base_layer = MagicMock(spec=[]) # empty spec to control attrs precisely
# Use a real tensor for weight (bf16, no quant_state attr)
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16)
base_layer.bias = None
base_layer.weight_scale_inv = torch.ones(1) # should NOT be used for bf16
proj.base_layer = base_layer
W, b, quant_state, A, B, s = get_lora_parameters(proj)
# quant_state should be None since weight is bf16, not FP8
self.assertIsNone(quant_state)
def test_fp8_weight_uses_scale_inv(self):
"""FP8 weight should pick up weight_scale_inv as quant_state."""
from axolotl.kernels.lora import get_lora_parameters
proj = MagicMock()
proj.disable_adapters = True
base_layer = MagicMock()
proj.base_layer = base_layer
# FP8 weight
base_layer.weight = torch.randn(64, 64, dtype=torch.bfloat16).to(
torch.float8_e4m3fn
)
base_layer.bias = None
scale_inv = torch.ones(1)
base_layer.weight_scale_inv = scale_inv
W, b, quant_state, A, B, s = get_lora_parameters(proj)
self.assertIs(quant_state, scale_inv)
class TestValidateQuantPatchRestore(unittest.TestCase):
"""Test that validate_quantization_for_training is restored after trainer creation."""
def test_patch_restored_on_success(self):
"""Monkeypatch should be restored even after successful trainer creation."""
import transformers.trainer as _trainer_module
original = _trainer_module.validate_quantization_for_training
# After the build() method runs, original should be restored.
# We can't easily test the full build(), but we can test the pattern.
_orig = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
pass # simulate trainer_cls() succeeding
finally:
_trainer_module.validate_quantization_for_training = _orig
self.assertIs(_trainer_module.validate_quantization_for_training, original)
def test_patch_restored_on_error(self):
"""Monkeypatch should be restored even if trainer creation raises."""
import transformers.trainer as _trainer_module
original = _trainer_module.validate_quantization_for_training
_orig = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
raise ValueError("test error")
except ValueError:
pass
finally:
_trainer_module.validate_quantization_for_training = _orig
self.assertIs(_trainer_module.validate_quantization_for_training, original)
if __name__ == "__main__":
unittest.main()

Some files were not shown because too many files have changed in this diff.