Compare commits

...

31 Commits

Author SHA1 Message Date
Wing Lian
a4a3b618e7 force torch to match when installing fa and deepspeed using uv 2026-03-04 10:00:08 -05:00
Wing Lian
b6b8db805a fix python version typo for building 3.11 (#3454) 2026-03-04 09:53:35 -05:00
Wing Lian
653f90be25 Add torch 2.10.0 to unit tests and use python 3.14 (#3450)
* Add torch 2.10.0 to unit tests and use python 3.14

* hold on python 3.14 checks due to mistral common

* add base option to matrix
2026-03-03 13:01:52 -05:00
NanoCode012
945c8aeb10 Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes (#3439)
* fix: saving clones state dict

* fix: apply fix for only CP mode

* fix: add dropout check when using lora target param

* fix: re-add patch from transformers PR #39866

* feat: add moe quant to test by ved

* fix: try match target param properly end with

* fix: clear cache per param quant

* fix: attempt on-load quantize experts instead of post-load

* fix: attempt disable async load

* chore: add log

* chore: adjust log

* fix: remove cuda alloc for moe and enable async load

* chore: remove leftover logs

* chore: add extra empty cache

* fix(doc): clarify support

* fix: handle fsdp2 for paramwrapper dtensor

* feat: attempt to quant experts in 8bit mode too

* feat: attempt to release bf16 experts from vram

* feat: upgrade cce

* fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init

* fix: remove unnecessary gc and empty cache

* Revert "fix: remove unnecessary gc and empty cache"

This reverts commit 1d54518990.

* fix: do not call full_tensor on non-dtensors

* fix: attempt to address fsdp2 with quant exp high loss

* fix: attempt lora quant experts wrong dim

* fix: ensure require_grad patch applied for lora 8bit

* fix: attempt lora 8bit fsdp2

* fix: attribute access on save for lora 8bit fsdp2

* fix: wrong weight attrib access

* chore(refactor): add config, re-arrange position of patches, clean
comments

* feat: add example docs

* chore: cherry pick trinity fixes from PR 3399

* chore: comments refactor; add guards

* fix: guard using wrong key

* fix: mamba save does not accept main process param

* fix: guard prevent double hook

* fix: move gc to upper scope

* chore: add comment on proxy forward patch

* fix: add comment to clarify

* feat: add test idempotency

* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter

* fix: AttributeError: 'NoneType' object has no attribute 'to'

* fix: update docs on cpu_ram_efficient_loading
2026-03-03 10:06:23 -05:00
NanoCode012
e672d37f33 fix: qwen3-next to use fla causal-conv1d to support packing (#3437)
* fix: qwen3-next to use fla causal-conv1d to support packing

* fix: causal import and update doc for v5

* fix: hard fail for packing without fla
2026-03-03 09:26:46 -05:00
Wing Lian
77828d3559 uv cloud image should use uv w pip (#3449) 2026-03-02 16:39:26 -05:00
Wing Lian
4272817109 don't install torch ao on arm64 (#3448) 2026-03-02 14:24:54 -05:00
Manas Vardhan
474208b794 fix: Save de-duplicated dataset during pre-processing (#3427)
* fix: run deduplication before saving dataset during preprocessing

Move deduplicate_and_log_datasets call before save_preprocessed_dataset
in both SFT and RL data loading pipelines. This ensures the saved
preprocessed dataset is already de-duplicated, so subsequent loads
from cache don't contain duplicates.

Fixes #2719

* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup

- Add dataset_exact_deduplication to the hash string in
  generate_dataset_hash_from_config so cached datasets are invalidated
  when the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and
  dataset_exact_deduplication=True, since dedup will be silently
  skipped in that configuration (both SFT and RL paths).

* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting

- Add config validator (check_deduplication_with_skip_prepare) that raises
  ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
  axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues

* refactor: inline deduplicate function per review feedback

* fix test fixture, lint

---------

Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-03-02 12:55:59 -05:00
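A minimal sketch of the cache-invalidation idea described in #3427 above: any flag that changes the prepared dataset (here, exact deduplication) must be folded into the string that `generate_dataset_hash_from_config` hashes, so toggling the flag yields a new cache key instead of silently reusing a stale cache. The function and field names below are illustrative, not Axolotl's actual API.

```python
import hashlib

def dataset_cache_key(dataset_paths: list[str], exact_dedup: bool) -> str:
    # Hypothetical stand-in for generate_dataset_hash_from_config: every config
    # field that alters the prepared dataset must be part of the hashed string.
    hash_input = "|".join(sorted(dataset_paths)) + f"|dedup={exact_dedup}"
    return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:16]

# Toggling dataset_exact_deduplication now produces a different cache key,
# so a previously saved (non-deduplicated) dataset is not reused by mistake.
assert dataset_cache_key(["alpaca_2k"], True) != dataset_cache_key(["alpaca_2k"], False)
```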
Wing Lian
444020b332 mark slow tests that are timing out in CI (#3428) [skip ci] 2026-03-02 12:26:30 -05:00
Wing Lian
aa88c2e30b fix uv cache subcommand (#3447) 2026-03-02 12:26:08 -05:00
NanoCode012
f447bce1db fix: do not push telemetry on non-master rank (#3438) 2026-03-02 15:31:20 +07:00
kallewoof
7f23b302d1 bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler() (#3435) [skip ci]
* bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler()

* nit: raise if self.optimizer is also unset

* optimizer properly optional in create_scheduler()
2026-03-02 15:30:07 +07:00
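A standalone sketch of the fallback behaviour described in #3435 above (the real change appears in the scheduler mixin diff further down this page): if no optimizer is passed, fall back to the trainer's own `self.optimizer`, and raise if that is unset too.

```python
import torch

class SchedulerMixinSketch:
    """Illustrative only; mirrors the optional-optimizer handling in create_scheduler."""

    def __init__(self, optimizer: torch.optim.Optimizer | None = None):
        self.optimizer = optimizer

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer | None = None
    ):
        if optimizer is None:
            if self.optimizer is None:
                raise ValueError(
                    "Optimizer must be set before calling create_scheduler "
                    "or passed as an argument."
                )
            optimizer = self.optimizer  # fall back to the trainer's optimizer
        # Placeholder constant schedule; the real mixin builds the configured scheduler.
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
```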
Wing Lian
18f26c19ef add uv axolotl builds (#3431) 2026-02-25 14:46:02 -05:00
Robert Ronan
2b6f4a6c9b Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying

* Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before.

* fix: refactor and add test truncate for non-input id fields

* fix: refactor long seq handling fn

* fix: refactor duplicate fn and simplify route

* add additional tests and make them work on mac

* handle logging exception on empty datasets

---------

Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-02-25 11:31:11 +07:00
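A minimal sketch of the map-vs-filter distinction behind #3401 above, using the `datasets` library directly with illustrative column data: `filter` drops over-length samples entirely (the old, incorrect behaviour for the truncation strategy), while `map` keeps every sample and shortens its fields in place.

```python
from datasets import Dataset

MAX_LEN = 4
ds = Dataset.from_dict({"input_ids": [[1, 2], [1, 2, 3, 4, 5, 6]]})

# "drop" strategy: filter removes the over-length sample entirely.
dropped = ds.filter(lambda ex: len(ex["input_ids"]) <= MAX_LEN)

# "truncate" strategy: map keeps the sample but truncates its fields.
truncated = ds.map(lambda ex: {"input_ids": ex["input_ids"][:MAX_LEN]})

print(dropped["input_ids"])    # [[1, 2]]
print(truncated["input_ids"])  # [[1, 2], [1, 2, 3, 4]]
```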
madScientist10
8f54b4eb25 fix: pass revision parameter to tokenizer and processor loaders (#3388) [skip ci]
* fix: pass revision parameter to tokenizer and processor loaders

* fix: address revision=None passed to .from_pretrained

* add tests and address review feedback for revision parameter

- Reformat modify_tokenizer_files signature and from_pretrained call
- Use kwargs pattern for modify_tokenizer_files call to avoid passing None revision
- Add 6 unit tests for revision parameter in tokenizer/processor loaders

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-02-25 11:11:20 +07:00
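The conditional-kwargs pattern mentioned in #3388 above, sketched for a generic tokenizer load: only forward `revision` when it is actually set, so `None` never reaches `.from_pretrained`. The helper name here is hypothetical; Axolotl's loaders apply the same idea.

```python
from transformers import AutoTokenizer

def load_tokenizer(model_name: str, revision: str | None = None):
    # Build kwargs conditionally so a None revision is never passed through.
    kwargs = {}
    if revision is not None:
        kwargs["revision"] = revision
    return AutoTokenizer.from_pretrained(model_name, **kwargs)

# Pin the tokenizer to a specific revision of the model repo.
tokenizer = load_tokenizer("zai-org/GLM-4.7-Flash", revision="main")
```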
VED
a131e4d0e5 sample gen support sft (#3240) [skip ci]
* add:parameters + callback

* sft core + logging

* indentation fix

* logger fix

* logger fix in sft

* gen sample on eval

* lint

* deprecation
2026-02-25 11:10:57 +07:00
Wing Lian
1791d87b6f build axolotl images with torch 2.10.0 (#3430) 2026-02-24 22:35:25 -05:00
Wing Lian
b40803da51 build base images for torch 2.10.0 (#3429) 2026-02-24 20:32:34 -05:00
Wing Lian
68f1b7004c ScatterMoE LoRA support (#3410)
* scattermoe lora support

* fsdp, bf16, dim fixes

* expert weights aren't needed in save for bwd since they are frozen

* use sonicmoe optim options

* update save model from upstream

* fixes per code review feedback and add tests

* revert removal of CP fix

* misc fixes
2026-02-24 14:59:55 -05:00
NanoCode012
08441fed17 fix: set allowed values for adapter config (#3415) 2026-02-23 11:39:53 -05:00
NanoCode012
86ca1e27c0 fix: update MistralProcessor to be v5 compat (#3423)
* fix: update MistralProcessor to be v5 compat

* feat: add test for mistral3 processor

* chore: comment
2026-02-23 11:39:13 -05:00
Manas Vardhan
5ed455715e feat: support dot-notation CLI args for nested config options (#3419)
* feat: support dot-notation CLI args for nested config options

Add support for overriding nested config fields (like TRL config) via
CLI using dot-notation, e.g.:
  axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1

Changes:
- args.py: Detect BaseModel subclass fields and generate dot-notation
  CLI options (--parent.child) that map to double-underscore kwargs
  (parent__child). Also fix _strip_optional_type for Python 3.10+
  union syntax (X | None).
- config.py: Handle double-underscore kwargs in load_cfg by setting
  nested dict values on the config.
- Add tests for nested option handling.

Fixes #2702

* Address CodeRabbit review: fix string parent bug, add type hints and docstring

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>

* Add type coercion for CLI kwargs and fix pre-commit issues

- Add _coerce_value() for YAML-style type inference on string CLI args
- When existing config value has a type (int/float/bool), cast to match
- When no existing value, infer type from string (true/false, ints, floats, null)
- Apply coercion to both flat and nested (dot-notation) kwargs
- Fix unused pytest import (pre-commit/ruff)
- Update tests to pass string values (matching real CLI behavior)
- Add dedicated TestCoerceValue test class

Addresses maintainer feedback on type casting for nested kwargs.

---------

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
2026-02-23 10:10:06 -05:00
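A condensed sketch of the kwarg plumbing described in #3419 above (the real implementation is in the `args.py`/`config.py` diffs later on this page): a CLI option such as `--trl.beta=0.1` arrives as the kwarg `trl__beta="0.1"`, is split on the double underscore, coerced from its CLI string, and written onto the nested config dict.

```python
def coerce(value: str):
    # YAML-style inference for CLI strings: bools, null, int, float, else str.
    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered in ("null", "none", "~"):
        return None
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value

def apply_cli_overrides(cfg: dict, **kwargs) -> dict:
    """Illustrative only: split parent__child kwargs and set nested config values."""
    for key, value in kwargs.items():
        if "__" in key:
            parent, child = key.split("__", 1)
            cfg.setdefault(parent, {})[child] = coerce(value)
        else:
            cfg[key] = coerce(value)
    return cfg

print(apply_cli_overrides({}, trl__beta="0.1", micro_batch_size="2"))
# {'trl': {'beta': 0.1}, 'micro_batch_size': 2}
```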
Lorenzo Baraldi
3f30572d4a Fix typo in dataset_processes field (#3426)
* Fix typo in dataset_processes field

* fix: use updated config name

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2026-02-23 14:18:37 +07:00
NanoCode012
43d60c7439 bump cut-cross-entropy to 58d6572 (#3424) 2026-02-20 14:24:51 -05:00
Wing Lian
0ea252d392 update to trackio 0.16.1 (#3425) [skip ci] 2026-02-20 14:24:33 -05:00
Wing Lian
29722dec60 use bunnycdn for CI assets (#3422) [skip ci] 2026-02-20 00:09:25 -05:00
NanoCode012
7fbedbd300 fix(doc): add limitation for unfrozen_parameters (#3416) 2026-02-19 18:32:26 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
NanoCode012
4f1b5ad29f fix: clarify how to use lm_eval plugin (#3404) [skip ci] 2026-02-15 07:52:30 -05:00
NanoCode012
d6a2532dd7 feat(doc): clarify how to use scattermoe (#3408) [skip ci]
* feat(doc): clarify how to use scattermoe

* chore: fix wording
2026-02-15 07:51:28 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
87 changed files with 7754 additions and 370 deletions

View File

@@ -51,14 +51,30 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "129"
cuda_version: 12.9.1
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -75,6 +91,14 @@ jobs:
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -157,14 +181,30 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "129"
cuda_version: 12.9.1
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -181,6 +221,14 @@ jobs:
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -34,16 +34,28 @@ jobs:
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 129
cuda_version: 12.9.1
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
@@ -86,6 +98,77 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -112,16 +195,28 @@ jobs:
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 129
cuda_version: 12.9.1
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
@@ -159,6 +254,73 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}

View File

@@ -37,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5

View File

@@ -54,13 +54,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -75,7 +75,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -149,13 +149,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -170,7 +170,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -264,8 +264,8 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 129
cuda_version: 12.9.1
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
@@ -326,6 +326,12 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -371,7 +377,7 @@ jobs:
include:
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:

View File

@@ -59,34 +59,18 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -0,0 +1,30 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

docker/Dockerfile-uv (new file, 47 lines)
View File

@@ -0,0 +1,47 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -6,6 +6,7 @@ ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -39,28 +40,18 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
]
},
{

View File

@@ -0,0 +1,77 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm4.7-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm4.7-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
```
### Expert LoRA
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router, untested but should work, not normally targeted
```
## Limitations
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise.
- **lora_target_linear**: Incompatible for this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
### TIPS
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
4. Run the finetuning example:
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 45.62 GiB VRAM.
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -9,6 +9,8 @@ plugins:
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_dropout: 0
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:

View File

@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM.
This config uses about 24.9 GiB VRAM (w/o CCE).
Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,5 +1,4 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF

View File

@@ -63,3 +63,5 @@ docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]

View File

@@ -2,25 +2,25 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.0.0
triton>=3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.4
liger-kernel==0.7.0
# END section
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.0.0
transformers==5.2.0
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.27.1
trl==0.28.0
hf_xet==1.2.0
kernels==0.11.5
kernels==0.12.1
trackio>=0.13.0
trackio>=0.16.1
typing-extensions>=4.15.0
optimum==1.16.2
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.16.0
openenv-core==0.1.0
schedulefree==1.4.1

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
)

View File

@@ -26,6 +26,11 @@ def parse_requirements(extras_require_map):
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
_install_requires = [
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [

View File

@@ -5,7 +5,7 @@ import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from typing import Any, Optional, Union
from urllib.parse import urlparse
import requests
@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -208,13 +265,37 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[key] = value
cfg[key] = _coerce_value(value, cfg.get(key))
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
for parent, children in nested_kwargs.items():
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -2,7 +2,7 @@
import dataclasses
from functools import wraps
from types import NoneType
from types import NoneType, UnionType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
@@ -20,7 +20,8 @@ def _strip_optional_type(field_type: type | str | None):
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
if is_union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
@@ -87,10 +88,70 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
return decorator
def _is_pydantic_model(field_type: type) -> bool:
"""Check if a type is a Pydantic BaseModel subclass."""
try:
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
except TypeError:
return False
def _get_field_description(field) -> str | None:
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
if field.description:
return field.description
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
return field.json_schema_extra.get("description")
return None
def _add_nested_model_options(
function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
"""
Add Click options for all fields of a nested Pydantic model using dot-notation.
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
Args:
function: Click command function to add options to.
parent_name: Parent field name (e.g., "trl").
model_class: Nested Pydantic model class.
Returns:
Function with added Click options.
"""
for sub_name, sub_field in reversed(model_class.model_fields.items()):
sub_type = _strip_optional_type(sub_field.annotation)
# Use dot notation: --parent.sub_field
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
# The kwarg name uses double-underscore as separator
param_name = f"{parent_name}__{sub_name}"
description = _get_field_description(sub_field)
if sub_type is bool:
option_name = f"--{cli_name}/--no-{cli_name}"
function = click.option(
option_name, param_name, default=None, help=description
)(function)
else:
option_name = f"--{cli_name}"
click_type = {str: str, int: int, float: float}.get(sub_type)
function = click.option(
option_name, param_name, default=None, type=click_type, help=description
)(function)
return function
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
generated for each sub-field (e.g., ``--trl.beta=0.1``).
Args:
config_class: PyDantic model with fields to parse from the CLI
@@ -103,6 +164,11 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
# Handle nested Pydantic models with dot-notation options
if _is_pydantic_model(field_type):
function = _add_nested_model_options(function, name, field_type)
continue
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
if getattr(self.cfg, "generate_samples", False):
from axolotl.utils.callbacks.generation import SFTGenerationCallback
callbacks.append(SFTGenerationCallback(trainer))
LOG.info("SFT sample generation enabled")
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -246,7 +252,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

View File

@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -53,6 +52,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
@@ -133,21 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs = ["max_prompt_length"]
blocklist_args_kwargs.append("max_prompt_length")
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -157,6 +154,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -719,13 +719,20 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
@@ -736,6 +743,7 @@ class AxolotlTrainer(
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
@@ -745,6 +753,7 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
else:
LOG.info(
@@ -772,11 +781,7 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:


@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
) -> LRScheduler:
"""
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,6 +45,13 @@ class SchedulerMixin(Trainer):
and self.args.cosine_min_lr_ratio is not None
)
if optimizer is None:
if self.optimizer is None:
raise ValueError(
"Optimizer must be set before calling create_scheduler or passed as an argument."
)
optimizer = self.optimizer
# fmt: off
if self.lr_scheduler is None: # type: ignore
# fmt: on


@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
```
## Usage
@@ -31,6 +31,7 @@ plugins:
## Supported Models
- afmoe
- apertus
- arcee
- cohere
@@ -51,6 +52,7 @@ plugins:
- glm4v
- glm4v_moe
- glm_image
- glm_moe_dsa
- gpt_oss
- granite
- granitemoe
@@ -76,14 +78,19 @@ plugins:
- olmo
- olmo2
- olmo3
- olmoe
- phi
- phi3
- phi4_multimodal
- qwen2
- qwen2_5_vl
- qwen2_moe
- qwen2_vl
- qwen2_5_vl
- qwen3
- qwen3_5
- qwen3_5_text
- qwen3_5_moe
- qwen3_5_moe_text
- qwen3_moe
- qwen3_next
- qwen3_vl


@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
)
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like(
self,
model_type: str,
model_type_to_patch: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str, remote_model_id: str | None
maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}"
) from e
if model_type not in PATCH_FNS:
if model_type_to_patch not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
"Setting up generic cce patch for model type: %s", model_type_to_patch
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)


@@ -0,0 +1,46 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training of MoE layers and reduce VRAM usage. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even faster and more efficient than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
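For orientation, the registration in step 1 amounts to mapping a kernel layer name to an implementation through the `kernels` library. A rough sketch follows (simplified, not the plugin's exact code; the MoE block class in the final comment is a placeholder):

```python
from kernels import (
    LayerRepository,
    Mode,
    register_kernel_mapping,
    replace_kernel_forward_from_hub,
)

# Point the "HFScatterMoEParallelExperts" kernel name at the ScatterMoE layer on the Hub.
register_kernel_mapping(
    {
        "HFScatterMoEParallelExperts": {
            "cuda": {
                Mode.TRAINING: LayerRepository(
                    repo_id="axolotl-ai-co/scattermoe",
                    layer_name="HFScatterMoEGatedMLP",
                ),
            },
        },
    }
)

# Step 2: swap the forward of the model's sparse MoE block class, e.g. (placeholder class):
# replace_kernel_forward_from_hub(OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts")
```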
## Limitations
ScatterMoE uses softmax -> top-k routing, so results may differ from the baseline for some model architectures (e.g. GPT-OSS, GLM_MOE_DSA).
ScatterMoE does not currently work for GLM4.7 Flash (glm4_moe_lite).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.


@@ -33,3 +33,16 @@ class KernelsArgs(BaseModel):
data["experts_implementation"] = "eager"
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
return data


@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]


@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import lora_ops, ops
__all__ = ["ops", "lora_ops"]

File diff suppressed because it is too large


@@ -0,0 +1,645 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import triton
import triton.language as tl
BLOCK_M = 128
ALLOW_TF32 = True
@triton.jit
def _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=True,
):
K_block = tl.arange(0, BLOCK_K)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = (
W_ptr
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
+ E_idx * stride_we
)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(iters):
if no_k_mask:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
return acc
def _scatter2scatter_configs():
return [
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
]
@triton.autotune(
configs=_scatter2scatter_configs(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _scatter2scatter(
X_ptr,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
W_ptr,
stride_we,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
Y_ptr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
B_ptr,
stride_be: tl.constexpr,
stride_bn: tl.constexpr,
grouped_idx_ptr,
expert_idxs_ptr,
# block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
# OUT_M,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr,
y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
M_boundary_mask = M_block < (FAN_OUT * M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
no_k_mask = K % BLOCK_K == 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E_first_idx = tl.min(E_idxs)
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
for E_idx in range(E_first_idx, E_last_idx + 1):
E_mask = E_idxs == E_idx
E_M_idx = M_idx
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = E_M_idx // FAN_OUT
acc = _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=allow_tf32,
)
if B_ptr is not None:
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def scatter2scatter(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
b=None,
x_grouped=False,
y_grouped=False,
out=None,
):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
output = out
scatter2scatter_compileable(
output,
W,
X,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
b,
x_grouped,
y_grouped,
)
return output
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
output: torch.Tensor,
W: torch.Tensor,
X: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
b: Optional[torch.Tensor],
x_grouped: bool,
y_grouped: bool,
) -> None:
def grid(META):
grid_num = (
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
* triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid_num
if b is None:
b = None
stride_be = stride_bn = 0
else:
stride_be, stride_bn = b.stride()
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X,
X.stride(0),
X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W,
W.stride(0),
W.stride(1),
W.stride(2),
# Y_ptr, stride_ym, stride_yn,
output,
output.stride(0),
output.stride(1),
# B_ptr, stride_be, stride_bn
b,
stride_be,
stride_bn,
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
# block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=output.size(1),
E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
x_grouped=x_grouped,
y_grouped=y_grouped,
)
def _config_XtY():
return [
triton.Config(
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
),
]
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
if has_bias:
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
else:
Db = None
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
return DW, Db
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
E: int,
DW: torch.Tensor,
Db: Optional[torch.Tensor],
DY: torch.Tensor,
X: torch.Tensor,
expert_offsets: torch.Tensor,
) -> None:
def grid(META):
grid = (
E * triton.cdiv(META["K"], META["BLOCK_K"]),
triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid
if Db is None:
stride_dbe = 0
stride_dbn = 0
else:
stride_dbe, stride_dbn = Db.stride()
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY,
DY.stride(0),
DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X,
X.stride(0),
X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW,
DW.stride(0),
DW.stride(1),
DW.stride(2),
# Db_ptr, stride_dwe, stride_dbn,
Db,
stride_dbe,
stride_dbn,
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0),
N=DY.size(-1),
K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
@triton.autotune(
configs=_config_XtY(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _groupXtY(
DY_ptr,
stride_dym,
stride_dyk,
X_ptr,
stride_xm,
stride_xn,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
expert_offsets_ptr,
M,
K: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = (
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
)
if (Db_ptr is not None) and (K_block_id == 0):
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=True,
)
else:
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=False,
)
@triton.jit
def _xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias: tl.constexpr,
):
if compute_bias:
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
else:
db_acc = None
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
if compute_bias:
db_acc += tl.sum(dy, axis=0)
DW_blk_ptrs = (
DW_ptr
+ E_idx * stride_dwe
+ K_block[:, None] * stride_dwk
+ N_block[None, :] * stride_dwn
)
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
if compute_bias:
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
def _config_grouping():
return [
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
return Y
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
A: torch.Tensor,
K: int,
N: int,
Y: torch.Tensor,
coeff: Optional[torch.Tensor],
has_coeff: bool,
fan_out: int,
sorted_expert_idxs: torch.Tensor,
) -> None:
def grid(META):
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A,
A.stride(0),
A.stride(1),
has_coeff,
coeff,
fan_out,
# Y_ptr, stride_yn, stride_yk,
Y,
Y.stride(0),
Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N,
K,
)
@triton.autotune(configs=_config_grouping(), key=["K"])
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
@triton.jit
def _group(
src_ptr,
stride_sn,
stride_sk,
has_coeff: tl.constexpr,
coeff_ptr,
FAN_OUT: tl.constexpr,
tgt_ptr,
stride_tn,
stride_ti,
grouped_idx_ptr,
N,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = (
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
)
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK or i < iters - 1:
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti


@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
import torch
import triton
import triton.language as tl
@triton.jit
def _single2scatter(
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
Y_ptr,
stride_ym,
stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
BLOCK_N,
)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = (
W_ptr
+ E_idx * stride_we
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
)
N_mask = N_block < N
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
K_mask = K_block < K
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
K_block += BLOCK_K
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = triton.cdiv(ydim, BLOCK_N), k
_single2scatter[grid](
X,
X.stride(0),
X.stride(1),
W,
W.stride(0),
W.stride(1),
W.stride(2),
Y,
Y.stride(0),
Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim,
N=ydim,
E=E,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32,
)
return Y


@@ -0,0 +1,439 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE layer replacements for HuggingFace MoE architectures.
Provides drop-in forward replacements that use ScatterMoE kernels for
acceleration. When used via the HF ``kernels`` library
(``replace_kernel_forward_from_hub``), these classes replace the forward
method of the original MoE block.
LoRA support
------------
When peft wraps parameters via ``target_parameters``, the ``self.experts``
submodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate``
router may also become a ``ParamWrapper``. The ``HFScatterMoEGatedMLP``
forward detects this and automatically:
1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta
2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module
3. Extracts LoRA A/B weights and scaling from each wrapper
4. Converts B layout from peft rank-major to scattermoe expert-major
5. Routes to ``parallel_linear_lora`` for fused LoRA computation
6. Passes through ``self.shared_expert`` / ``self.shared_expert_gate``
(peft wraps their linear layers with standard LoRA, no special handling)
"""
import torch
from torch import nn
from torch.nn import functional as F
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
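# Illustrative shape check for the conversion above (a sketch, not executed by the library):
#   E, r, dim1, dim2 = 8, 16, 1024, 2048
#   peft_A = torch.randn(E * r, dim1)      # peft lora_A: [r*E, in_features=dim1]
#   peft_B = torch.randn(dim2, E * r)      # peft lora_B: [out_features=dim2, r*E]
#   smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
#   assert smoe_A.shape == (E * r, dim2)   # scattermoe lora_A: [r*E, K=dim2]
#   assert smoe_B.shape == (dim1, E * r)   # scattermoe lora_B: [N=dim1, r*E]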
# =============================================================================
# ParamWrapper unwrapping
# =============================================================================
def _unwrap_gate_lora(gate_module):
"""Unwrap peft ``ParamWrapper`` on the router gate.
When peft targets ``gate.weight``, ``self.gate`` becomes::
ParamWrapper(weight)
-> base_layer: OlmoeTopKRouter (the real module)
This function detects the wrapping and returns the base router, its
weight tensor, and an optional LoRA delta tensor.
Returns:
(base_gate, gate_weight, gate_lora_delta_or_None)
``base_gate`` is the original router module (with ``.top_k``,
``.num_experts``, ``.norm_topk_prob``).
``gate_weight`` is the base router weight (may be a DTensor under FSDP).
``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active,
else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add.
"""
if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"):
base_gate = gate_module.base_layer
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
if lora_A is not None:
# gate weight: [num_experts, hidden_size]
# lora_A: [r, hidden_size], lora_B: [num_experts, r]
# delta = scaling * B @ A = [num_experts, hidden_size]
delta = scaling * (lora_B @ lora_A)
return base_gate, base_gate.weight, delta
else:
return base_gate, base_gate.weight, None
else:
# No wrapping — gate is the original module
return gate_module, gate_module.weight, None
def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling):
"""Convert peft LoRA weights to scattermoe layout."""
smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
return (smoe_A, smoe_B, scaling)
def _unwrap_experts_lora(experts_module):
"""Walk a peft ``ParamWrapper`` chain on ``self.experts``.
When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via
``target_parameters``, ``self.experts`` becomes a nested chain::
ParamWrapper(down_proj)
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
Returns:
(base_experts, gup_lora, down_lora)
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
module = experts_module
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
param_name = getattr(module, "parameter_name", None)
if param_name is not None:
wrappers[param_name] = module
module = module.base_layer
base_experts = module
if not wrappers:
return base_experts, None, None
# Determine num_experts from base module
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
gup = getattr(base_experts, "gate_up_proj", None)
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
return base_experts, gup_lora, down_lora
# =============================================================================
# Layer classes
# =============================================================================
class ScatterMoEGatedMLP(nn.Module):
def forward(self, layer_input):
"""
Forward pass of the mixture of experts layer.
Args:
layer_input (Tensor):
Input tensor.
Returns:
Tensor:
Output tensor.
"""
bsz, length, emb_size = layer_input.size()
layer_input = layer_input.reshape(-1, emb_size)
# compute the top_k routing decision
router_logits = self.router.layer(layer_input)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.router.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(layer_input.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=self.router.num_experts
)
# compute experts
gates, h = parallel_linear(
layer_input,
self.input_linear.weight.transpose(2, 1),
self.router.top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
).chunk(2, dim=-1)
h = self.activation(gates) * h
layer_output = parallel_linear(
h,
self.output_linear.weight.transpose(2, 1),
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
layer_output = layer_output.view(bsz, length, emb_size)
return layer_output
class HFScatterMoEGatedMLP(nn.Module):
"""
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
method replaces the original ``OlmoeSparseMoeBlock.forward``.
Supports both full-parameter training and LoRA fine-tuning:
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
def forward(self: nn.Module, layer_input: torch.Tensor):
"""
Forward pass using ScatterMoE kernels.
Args:
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
Returns:
Tensor: [batch_size, seq_len, hidden_size]
"""
batch_size, sequence_length, hidden_dim = layer_input.shape
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE)
# ====================================================================
# peft wraps individual linear layers inside shared_expert with
# standard LoRA — calling forward() handles this transparently.
if hasattr(self, "shared_expert") and self.shared_expert is not None:
shared_expert_output = self.shared_expert(hidden_states_flat)
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate_output = F.sigmoid(
self.shared_expert_gate(hidden_states_flat)
)
shared_expert_output = shared_expert_output * shared_expert_gate_output
else:
shared_expert_output = None
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
router_logits = F.linear(hidden_states_flat, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states_flat, gate_lora_delta
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if base_gate.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=num_experts
)
# ====================================================================
# Detect LoRA (peft ParamWrapper) and extract adapter weights
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
grouped_in=False,
grouped_out=True,
use_fused_dX=True,
use_fused_gather=True,
)
else:
gup = parallel_linear(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
)
gates, h = gup.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
expert_output = parallel_linear_lora(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
gates=routing_weights,
grouped_in=True,
grouped_out=False,
use_fused_dX=True,
use_fused_gather=True,
)
else:
expert_output = parallel_linear(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================
if shared_expert_output is not None:
expert_output = expert_output + shared_expert_output
expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
return expert_output


@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ParallelExperts module with LoRA support.
Provides a drop-in replacement for ScatterMoE's ParallelExperts that
uses the fused LoRA kernel when adapter weights are attached.
"""
from typing import Optional
import torch
import torch.nn as nn
from .parallel_linear_lora import parallel_linear_lora
class ParallelExperts(nn.Module):
"""
Parallel Experts with fused LoRA support.
Drop-in replacement for the original ParallelExperts. When LoRA parameters
are attached via set_lora(), the forward pass uses a fused kernel:
Y = X @ W + scaling * (X @ A^T) @ B^T
"""
def __init__(
self,
num_experts: int,
input_size: int,
output_size: int,
bias: bool = False,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self._lora_A: torch.Tensor | None = None
self._lora_B: torch.Tensor | None = None
self._lora_scaling: float | None = None
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def extra_repr(self) -> str:
return (
f"num_experts={self.num_experts}, "
f"input_size={self.input_size}, "
f"output_size={self.output_size}"
)
def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
"""Attach LoRA parameters for fused computation."""
self._lora_A = lora_A
self._lora_B = lora_B
self._lora_scaling = scaling
def clear_lora(self):
"""Remove LoRA parameters."""
self._lora_A = None
self._lora_B = None
self._lora_scaling = None
def forward(
self,
inputs: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
) -> torch.Tensor:
return parallel_linear_lora(
inputs,
self.weight.permute(0, 2, 1), # [E, input, output]
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=self._lora_A,
lora_B=self._lora_B,
scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
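# Illustrative usage (a sketch; sizes are hypothetical and the routing indices come from
# flatten_sort_count over the router's top-k expert ids):
#   experts = ParallelExperts(num_experts=8, input_size=1024, output_size=4096)
#   experts.set_lora(lora_A, lora_B, scaling=2.0)
#       lora_A: [r * num_experts, input_size], lora_B: [output_size, r * num_experts]
#       (scattermoe expert-major layout)
#   y = experts(x, k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets)
#   experts.clear_lora()  # subsequent calls use only the frozen base weights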


@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import torch.nn as nn
from . import kernels
@torch.library.custom_op("scattermoe::bincount", mutates_args={})
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
return x.bincount(minlength=minlength)
@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
return torch.empty(minlength, dtype=torch.long, device=x.device)
@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
with torch.no_grad():
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
expert_counts = compileable_bincount(
flattened_expert_idxs, minlength=num_experts
)
expert_offsets = expert_counts.cumsum(-1)
return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
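# Toy illustration (not part of the library): two tokens, top-2 routing over 3 experts.
#   expert_idxs = torch.tensor([[0, 2], [1, 0]])
#   sorted_expert_idxs -> tensor([0, 0, 1, 2])
#   expert_offsets     -> tensor([2, 3, 4])   (cumulative token-slot count per expert)
#   sorted_scattered_idxs maps each sorted slot back to its position in the flattened
#   routing; the order among slots of the same expert depends on the sort's tie-breaking.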
class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
):
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x,
W=expert_weights,
b=expert_biases,
k=k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
if gates is not None:
# calculate gates gradient
# d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
grouped_grad_out = output_expanded.flatten(
0, 1
) # reuse expanded buffer later
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = kernels.ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
d_weights, d_biases = kernels.ops.group_bwd_W(
DY=grouped_grad_out,
X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0),
has_bias=expert_biases is not None,
)
d_expanded_input = kernels.ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1),
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input, # Reuse grouped_x buffer
)
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
return (
# x, expert_weights,
d_input,
d_weights,
# k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
None,
None,
None,
None,
# bias, gates
d_biases,
d_gates,
# grouped_in, grouped_out,
None,
None,
)
def parallel_linear(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=None,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self.reset_parameters()
def extra_repr(self):
return "num_experts={}, input_size={}, output_size={}".format(
self.num_experts, self.input_size, self.output_size
)
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(
self,
inputs,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = parallel_linear(
inputs,
self.weight.permute(0, 2, 1),
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
return results


@@ -0,0 +1,480 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE + LoRA Autograd Function
====================================
Provides the autograd function and Python interface for fused ScatterMoE + LoRA.
Key design for LoRA training:
- Expert weights W are FROZEN (no gradient computed for W).
- Only LoRA adapter weights (A, B) receive gradients.
- The input gradient dX is still computed (needed for upstream layers).
- This avoids the expensive group_bwd_W computation entirely.
Forward:
Y = X @ W + scaling * (X @ A^T) @ B^T
Backward (W frozen):
dX = dY @ W^T + scaling * (dY @ B) @ A (via scatter2scatter for base, separate for LoRA)
dA = scaling * (dY @ B)^T @ X (per-expert, on grouped data)
dB = scaling * dY^T @ (X @ A^T) (per-expert, on grouped data)
"""
from typing import Optional
import torch
from .kernels import ops as base_ops
from .kernels.lora_ops import (
group_bwd_lora,
group_bwd_lora_fused,
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
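# Dense single-expert reference for the fused forward (illustrative only; the kernels
# below compute this per expert over the token groups built by the caller):
#   X: [M, K], W_e: [K, N], A_e: [r, K], B_e: [N, r]
#   Y_e = X @ W_e + scaling * (X @ A_e.T) @ B_e.T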
class ScatterMoELoRA(torch.autograd.Function):
"""
Autograd function for fused ScatterMoE + LoRA with frozen expert weights.
This function is optimized for the LoRA fine-tuning scenario where:
- Expert weights W are frozen (requires_grad=False)
- Only LoRA A and B matrices receive gradients
- Input gradients are computed for upstream layer backprop
"""
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
with torch.device(x.device):
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
output = scatter2scatter_lora(
X=x,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
b=expert_biases,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
# Handle gating (weighted combination of top-k expert outputs)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
# Store frozen weights as plain Python attributes instead of
# save_for_backward. This avoids:
# 1. Version-check conflicts with FSDP unshard/reshard
# 2. Pinning all-gathered parameters via saved_tensors hooks
# 3. Interfering with activation offloading pack/unpack hooks
# Safe because expert_weights are frozen (requires_grad=False).
ctx.expert_weights = expert_weights
ctx.expert_biases = expert_biases
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
ctx.scaling = scaling
ctx.use_fused_dX = use_fused_dX
ctx.use_fused_gather = use_fused_gather
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
expert_weights = ctx.expert_weights
k = ctx.k
scaling = ctx.scaling
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
E = expert_weights.size(0)
# ------------------------------------------------------------------
# Gate gradients (if using top-k gating with routing weights)
# ------------------------------------------------------------------
if gates is not None:
# d_gates[t, j] = output_expanded[t, j, :] . grad_out[t, :]
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# Reuse output_expanded buffer for grouped_grad_out
grouped_grad_out = output_expanded.flatten(0, 1)
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
# ------------------------------------------------------------------
# LoRA gradients (dA, dB) and setup for dX
# ------------------------------------------------------------------
# Fused gather uses sorted_scattered_idxs for indirect X access
# in the Triton kernel, avoiding the group(x) allocation.
#
# can_fuse_gather: X is ungrouped and not too large for scatter loads
# - When gates is None and grouped_out=False: both DY and X ungrouped
# - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped
# -> use dy_grouped=True in the fused kernel
M_total = sorted_scattered_idxs.size(0)
K_dim = x.size(-1)
N_dim = expert_weights.size(-1)
fuse_gather_workload = M_total * max(K_dim, N_dim)
_FUSE_GATHER_THRESHOLD = 2**24 # ~16M elements
can_fuse_gather = (
ctx.use_fused_gather
and not grouped_in # X must be ungrouped for scatter access
and gates is None # gate coeff requires multiplicative gather
and fuse_gather_workload < _FUSE_GATHER_THRESHOLD
)
if can_fuse_gather:
# ------------------------------------------------------------------
# Fused path: skip group(x) entirely
# ------------------------------------------------------------------
d_expanded_input = None
d_lora_A, d_lora_B = group_bwd_lora_fused(
DY=grad_out,
X=x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
sorted_scattered_idxs=sorted_scattered_idxs,
E=E,
k=k,
scaling=scaling,
dy_grouped=grouped_out,
)
# Prepare grouped_grad_out for the dX path (needed by both
# the fused dX kernel when grouped_out=True, and the non-fused path)
if grouped_out:
grouped_grad_out = grad_out
elif not ctx.use_fused_dX:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
else:
# ------------------------------------------------------------------
# Original path: explicit group() calls
# ------------------------------------------------------------------
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x # Will be overwritten; reuse buffer
d_lora_A, d_lora_B = group_bwd_lora(
DY=grouped_grad_out,
X=grouped_x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
E=E,
scaling=scaling,
)
# ------------------------------------------------------------------
# Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A
# ------------------------------------------------------------------
if ctx.use_fused_dX:
if can_fuse_gather and not grouped_out:
# Fully fused: read ungrouped DY via scatter pattern
d_expanded_input = scatter2scatter_lora_dX(
DY=grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=False,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Fused dX only: read from pre-grouped DY
d_expanded_input = scatter2scatter_lora_dX(
DY=grouped_grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=True,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Original path: separate base scatter2scatter + LoRA Python loop
d_expanded_input = base_ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1), # [E, N, K]
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input,
)
# LoRA part: dX_lora = scaling * (dY @ B) @ A
if scaling != 0.0:
d_input_lora_grouped = _compute_lora_input_grad(
grouped_grad_out,
lora_A,
lora_B,
expert_offsets,
E,
scaling,
)
if grouped_in:
d_expanded_input.add_(d_input_lora_grouped)
else:
# Scatter-add LoRA gradient directly into d_expanded_input.
# Avoids allocating a zeros_like + add result
d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped
# Reduce over top-k if k > 1
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
# W is frozen during LoRA training -- skip weight gradient
d_weights = (
torch.zeros_like(expert_weights)
if expert_weights.requires_grad
else None
)
d_biases = None
return (
d_input,
d_weights,
None,
None,
None,
None, # k, sorted indices, offsets
d_lora_A,
d_lora_B,
None, # lora_A, lora_B, scaling
d_biases,
d_gates,
None,
None, # grouped_in, grouped_out
None, # use_fused_dX
None, # use_fused_gather
)
def _compute_lora_input_grad(
grouped_grad_out: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
expert_offsets: torch.Tensor,
E: int,
scaling: float,
) -> torch.Tensor:
"""
Compute the LoRA contribution to the input gradient:
dX_lora = scaling * (dY @ B) @ A
Uses PyTorch ops on expert-grouped data.
Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e
"""
R = lora_A.size(0) // E
K = lora_A.size(1)
M_total = grouped_grad_out.size(0)
d_input_lora = torch.zeros(
(M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype
)
compute_dtype = grouped_grad_out.dtype
prev_offset = 0
for e in range(E):
curr_offset = expert_offsets[e].item()
if curr_offset > prev_offset:
dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N]
a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # [r, K]
b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r]
# dX_e = scaling * (dY_e @ B_e) @ A_e
dy_b = dy_e @ b_e # [M_e, r]
dx_e = scaling * (dy_b @ a_e) # [M_e, K]
d_input_lora[prev_offset:curr_offset] = dx_e
prev_offset = curr_offset
return d_input_lora
# =============================================================================
# Helper: Extract LoRA params from PEFT ParamWrapper
# =============================================================================
def get_lora_params_from_wrapper(module) -> tuple:
"""
Extract LoRA parameters from a PEFT ParamWrapper.
Returns:
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
"""
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
return None, None, None
active_adapters = getattr(module, "active_adapters", ["default"])
if not active_adapters:
return None, None, None
adapter_name = active_adapters[0]
lora_A_dict = getattr(module, "lora_A", {})
lora_B_dict = getattr(module, "lora_B", {})
scaling_dict = getattr(module, "scaling", {})
if adapter_name not in lora_A_dict:
return None, None, None
lora_A = lora_A_dict[adapter_name].weight
lora_B = lora_B_dict[adapter_name].weight
scaling = scaling_dict[adapter_name]
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: Optional[torch.Tensor] = None,
lora_B: Optional[torch.Tensor] = None,
scaling: float = 1.0,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
"""
Drop-in replacement for parallel_linear that supports LoRA.
If lora_A and lora_B are provided, uses fused LoRA kernel.
Otherwise falls back to standard scatter2scatter.
"""
if lora_A is not None and lora_B is not None:
return ScatterMoELoRA.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A,
lora_B,
scaling,
expert_biases,
gates,
grouped_in,
grouped_out,
use_fused_dX,
use_fused_gather,
)
else:
from .parallel_experts import ParallelLinear
return ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)


@@ -1,5 +1,7 @@
from pathlib import Path
from kernels import (
LayerRepository,
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
@@ -19,16 +21,19 @@ class KernelsPlugin(BasePlugin):
self._kernelize_model(cfg.model_config_type)
def _register_kernels(self):
plugin_root = Path(__file__).parent
register_kernel_mapping(
{
"HFScatterMoEParallelExperts": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="axolotl-ai-co/scattermoe",
Mode.TRAINING: LocalLayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP",
),
Mode.INFERENCE: LayerRepository(
repo_id="axolotl-ai-co/scattermoe",
Mode.INFERENCE: LocalLayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP",
),
},

View File

@@ -6,6 +6,12 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
@@ -16,9 +22,50 @@ lm_eval_tasks:
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation
```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
model=get_model_path(cfg),
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -13,6 +13,21 @@ import yaml
from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
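# Hedged illustration of the priority order; the DictDefault below mirrors the
# YAML config and its values are made up.
_example_cfg = DictDefault(
    {"lm_eval_model": None, "hub_model_id": "org/my-sft-model", "output_dir": "./outputs"}
)
assert get_model_path(_example_cfg) == "org/my-sft-model"  # hub_model_id beats output_dir
_example_cfg["lm_eval_model"] = "meta-llama/Llama-2-7b-hf"
assert get_model_path(_example_cfg) == "meta-llama/Llama-2-7b-hf"  # explicit override wins
_example_cfg["lm_eval_model"] = _example_cfg["hub_model_id"] = None
assert get_model_path(_example_cfg) is None  # caller then falls back to output_dir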
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
@@ -108,7 +123,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
model=get_model_path(cfg),
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
return self
for param in model.parameters():
if isinstance(param, Params4bit):
if isinstance(param, Params4bit) and param.quant_state is not None:
param.quant_state._orig_to = param.quant_state.to
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

View File

@@ -172,7 +172,10 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
self.patch_manager.apply_post_model_build_patches(self.model)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
@@ -860,6 +863,10 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if getattr(self.model, "_moe_experts_quantized", False):
# Parametrized expert tensors dequantize on access; running kbit preparation here would OOM.
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -10,6 +10,7 @@ from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
@@ -117,6 +118,7 @@ class PatchManager:
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -134,6 +136,10 @@ class PatchManager:
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -169,9 +175,14 @@ class PatchManager:
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_accelerate_fsdp2,
patch_tied_keys_for_meta_device,
)
patch_accelerate_fsdp2()
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
patch_tied_keys_for_meta_device()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -328,7 +339,7 @@ class PatchManager:
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
if has_remote_code and self.cfg.trust_remote_code is not None:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code
@@ -351,15 +362,54 @@ class PatchManager:
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_dtype_attrs_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
apply_linear8bitlt_save_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
apply_init_dtype_attrs_patch()
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
"""Log quantization results and set model flag for downstream use."""
import torch
model._moe_experts_quantized = False
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
count = get_moe_quantized_count()
if count > 0:
import gc
model._moe_experts_quantized = True
LOG.info(
"Quantized %d MoE expert parameter(s) to %s during model loading",
count,
"4-bit" if self.cfg.load_in_4bit else "8-bit",
)
gc.collect()
torch.cuda.empty_cache()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
@@ -500,6 +550,7 @@ class PatchManager:
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference
):
# TODO(MengqingCao): split these patches separately

View File

@@ -19,6 +19,11 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
# Build common kwargs for processor loading
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer():
@@ -40,6 +45,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
from axolotl.utils.mistral import Mistral3Processor
@@ -48,10 +54,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
tokenizer=tokenizer,
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
# Attempt to load image size from processor if available

View File

@@ -28,7 +28,10 @@ PLUGIN_MANAGER = PluginManager.get_instance()
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
tokenizer_path: str,
token_mappings: dict[int, str],
output_dir: str,
revision: str = "main",
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory,
@@ -41,6 +44,7 @@ def modify_tokenizer_files(
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
revision: Model revision/branch/tag/commit to load from (HF Hub)
Returns:
Path to the modified tokenizer directory
@@ -53,7 +57,9 @@ def modify_tokenizer_files(
if is_local_main_process():
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
temp_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=True, revision=revision
)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
@@ -134,7 +140,10 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
from axolotl.utils.mistral import HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
kwargs = {}
if cfg.revision_of_model:
kwargs["revision"] = cfg.revision_of_model
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
return tokenizer
@@ -150,6 +159,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if cfg.revision_of_model:
tokenizer_kwargs["revision"] = cfg.revision_of_model
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
@@ -161,8 +172,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
modify_kwargs = {"output_dir": cfg.output_dir}
if cfg.revision_of_model:
modify_kwargs["revision"] = cfg.revision_of_model
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
)
tokenizer = tokenizer_cls.from_pretrained(

View File

@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
**kwargs,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
from torch.distributed.tensor import DTensor
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
param = param.full_tensor()
if isinstance(param, DTensor):
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def patch_peft_param_wrapper_for_fsdp2():
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
delta_weight to the base weight W inside _LoraParameterProxy.forward().
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
This patch promotes the non-DTensor operand to match the DTensor's spec
using DTensor.from_local(), which is free for Replicate placement (just
metadata wrapping, no communication).
"""
from peft.tuners.lora.layer import _LoraParameterProxy
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
return
_original_forward = _LoraParameterProxy.forward
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
def _patched_forward(self, W):
from torch.distributed.tensor import DTensor
delta = self.delta_weight
w_is_dt = isinstance(W, DTensor)
d_is_dt = isinstance(delta, DTensor)
with torch.nn.utils.parametrize.cached():
if w_is_dt == d_is_dt:
return W + delta
if w_is_dt:
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
_LoraParameterProxy.forward = _patched_forward
_LoraParameterProxy._axolotl_fsdp2_patched = True
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from peft.tuners.lora.layer import ParamWrapper
from torch.distributed.fsdp import fully_shard
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
# The parent decoder layer's FSDP wrapper handles unsharding them.
# TODO: review if we even need to shard them separately in the first place.
if isinstance(module, ParamWrapper):
return False
log_bias_dtype_mismatch = False
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
is_peft_model = isinstance(model, PeftModel)
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
if is_peft_model:
from peft.tuners.lora.layer import ParamWrapper
if any(isinstance(m, ParamWrapper) for m in model.modules()):
patch_peft_param_wrapper_for_fsdp2()
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
@@ -376,6 +434,43 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
return model
def patch_tied_keys_for_meta_device():
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
grouped as "tied". Skipping them is safe since they have no real storage.
"""
from collections import defaultdict
from transformers import PreTrainedModel
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
param_pointers = defaultdict(list)
for param_name, param_value in self.state_dict().items():
if param_value.is_meta:
continue
param_pointers[param_value.data_ptr()].append(param_name)
tied_param_names = [
names
for names in param_pointers.values()
if len(names) > 1
and not any(name in self.all_tied_weights_keys.keys() for name in names)
and not all(name in missing_keys for name in names)
]
tied_weights_keys_by_pointers = {
param_name: group[0]
for group in tied_param_names
for param_name in group[1:]
}
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
_patched_adjust_tied_keys_with_tied_pointers
)
def patch_accelerate_fsdp2():
import accelerate

View File

@@ -1,9 +1,10 @@
"""
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
our LoRA / QLoRA Triton kernels to work with FSDP2.
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
metadata through the FSDP2 shard/unshard cycle.
"""
import importlib
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
elif isinstance(param, bnb.nn.modules.Int8Params):
self.sharded_param = bnb.nn.modules.Int8Params(
data=sharded_param,
requires_grad=param.requires_grad,
has_fp16_weights=param.has_fp16_weights,
CB=None,
SCB=param.SCB,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)"""
# Apply the replacement
if original_param_creation in original_source:
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param
apply_init_sharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
self._unsharded_param = bnb.nn.modules.Int8Params(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
has_fp16_weights=local_tensor.has_fp16_weights,
CB=unsharded_param,
SCB=local_tensor.SCB,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param
apply_init_unsharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")
def apply_linear8bitlt_save_patch():
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
doesn't proxy custom attribute access to its _local_tensor. This patch
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
"""
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
return
import bitsandbytes as bnb
from torch.distributed.tensor import DTensor
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
weight = self._parameters["weight"]
unwrapped = False
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
self._parameters["weight"] = weight._local_tensor
unwrapped = True
try:
original_save(self, destination, prefix, keep_vars)
finally:
if unwrapped:
self._parameters["weight"] = weight
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
apply_linear8bitlt_save_patch._axolotl_patched = True
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
def apply_init_dtype_attrs_patch():
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
without extensions.
"""
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
def patched_init_dtype_attrs(self, mp_policy):
original_init_dtype_attrs(self, mp_policy)
# Skip casting non-float quantized params (uint8/int8) without FSDP2
# extensions — the parametrization chain handles dequantization.
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
local = self.sharded_param
if hasattr(local, "_local_tensor"):
local = local._local_tensor
if not hasattr(local, "fsdp_pre_all_gather"):
self.param_dtype = None
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
apply_init_dtype_attrs_patch._axolotl_patched = True
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")

View File

@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
output = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
torch.autograd.backward(output, dY)
return (
None,

View File

@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
except ImportError:
fla_causal_conv1d = None
def get_cu_seqlens(position_ids):
"""
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
and cache_position is not None
)
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
cu_seqlens = None
if not use_precomputed_states and position_ids is not None:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
@@ -151,12 +161,11 @@ def patch_qwen3_next_gateddelta_layer():
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
# Inference single-token path: causal_conv1d_update expects [B, D, T]
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -164,24 +173,41 @@ def patch_qwen3_next_gateddelta_layer():
self.conv1d.bias,
self.activation,
)
mixed_qkv = mixed_qkv.transpose(1, 2)
else:
if cache_params is not None:
# Cache state expects [B, D, T] for the inference update path
mixed_qkv_t = mixed_qkv.transpose(1, 2)
conv_state = F.pad(
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
)
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
if fla_causal_conv1d is not None:
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
cu_seqlens=cu_seqlens,
)
else:
# PyTorch fallback (no cu_seqlens support)
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
raise RuntimeError(
"Packed sequences require fla.modules.convolution.causal_conv1d "
"(cu_seqlens support). Install flash-linear-attention or disable packing."
)
LOG.warning_once(
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = mixed_qkv.transpose(1, 2)
# mixed_qkv is [B, T, D] in all paths
query, key, value = torch.split(
mixed_qkv,
[
@@ -203,7 +229,6 @@ def patch_qwen3_next_gateddelta_layer():
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,

View File

@@ -0,0 +1,188 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
self.register_buffer("row_stats", row_stats)
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
return result.reshape(orig_shape)
def _enable_parametrization_cache(module, inputs):
P._cache_enabled += 1
def _disable_parametrization_cache(module, inputs, output):
P._cache_enabled -= 1
if not P._cache_enabled:
P._cache = {}
def replace_parameter_8bit(module, param_name):
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
original_param = getattr(module, param_name)
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
original_param.data.to(torch.float16)
)
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
del original_param
P.register_parametrization(
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
)
# Cache dequantized values during forward to avoid redundant dequantization.
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
module.register_forward_pre_hook(_enable_parametrization_cache)
module.register_forward_hook(_disable_parametrization_cache)
module._axolotl_8bit_hooks_registered = True
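# Hedged usage sketch (assumes a CUDA device and a recent bitsandbytes); the toy
# module and tensor sizes below are made up:
class _ToyExperts(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # fused-experts style 3D parameter: [num_experts, in_features, out_features]
        self.gate_up_proj = torch.nn.Parameter(torch.randn(4, 16, 32, device="cuda"))

_experts = _ToyExperts()
replace_parameter_8bit(_experts, "gate_up_proj")
# Storage is now int8; attribute access runs the parametrization and returns a
# dequantized tensor with the original shape.
assert _experts.parametrizations["gate_up_proj"].original.dtype == torch.int8
print(_experts.gate_up_proj.shape)  # torch.Size([4, 16, 32])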
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
return
import transformers.core_model_loading
import transformers.modeling_utils
if mode == "4bit":
from bitsandbytes.nn.parametrize import replace_parameter_4bit
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
if compress_statistics is None:
compress_statistics = True
_moe_load_state["quant_type"] = quant_type
_moe_load_state["compress_statistics"] = compress_statistics
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
# size for all params, defeating our on-load quantization VRAM savings.
def _noop_warmup(*args, **kwargs):
pass
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
original_set_param = transformers.core_model_loading.set_param_for_module
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
if "expert" not in target_name.lower():
LOG.debug(
"Skipping non-expert 3D param: %s (shape=%s)",
target_name,
list(param_value.shape),
)
return
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
else:
replace_parameter_8bit(mod, pname)
_moe_load_state["count"] += 1
# Release the bf16 tensor so CUDA memory is freed immediately.
param_value.data = torch.empty(0, device="cpu")
torch.cuda.empty_cache()
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
def get_moe_quantized_count():
"""Return the number of expert parameters quantized during loading."""
return _moe_load_state["count"]
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
for target in original_targets:
mod_path, _, param_name = target.rpartition(".")
if (
module_name == mod_path or module_name.endswith("." + mod_path)
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")

View File

@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
ORIGINAL_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
def check_evaluation_loop_is_patchable() -> bool:

View File

@@ -156,6 +156,10 @@ class TelemetryManager:
Returns:
Boolean denoting whether telemetry is enabled or not.
"""
# Only rank 0 will send telemetry
if not is_main_process():
return False
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
@@ -169,10 +173,6 @@ class TelemetryManager:
):
return True
# Only rank 0 will send telemetry
if not is_main_process():
return False
if do_not_track is None:
do_not_track = "0"

View File

@@ -0,0 +1,84 @@
"""Callback for generating samples during SFT/Pretrain training."""
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.utils.generation.sft import generate_samples
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SFTGenerationCallback(TrainerCallback):
"""Callback for generating samples during SFT/Pretrain training."""
def __init__(self, trainer):
self.trainer = trainer
def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Generate samples at specified intervals."""
cfg = self.trainer.axolotl_cfg
if not getattr(cfg, "generate_samples", False):
return
dataloader = None
try:
if getattr(self.trainer, "eval_dataset", None) is not None:
dataloader = self.trainer.get_eval_dataloader()
LOG.info(
f"Using eval dataloader for generation at step {state.global_step}"
)
except Exception as e:
LOG.warning(f"Could not get eval dataloader: {e}")
dataloader = None
if dataloader is None:
dataloader = self.trainer.get_train_dataloader()
LOG.info(
f"Using train dataloader for generation at step {state.global_step}"
)
samples = generate_samples(
model=self.trainer.model,
tokenizer=self.trainer.processing_class,
dataloader=dataloader,
num_generation_samples=getattr(cfg, "num_generation_samples", 3),
max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
temperature=getattr(cfg, "generation_temperature", 0.7),
top_p=getattr(cfg, "generation_top_p", None),
top_k=getattr(cfg, "generation_top_k", None),
do_sample=getattr(cfg, "generation_do_sample", True),
prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
)
self._log_samples(samples, state.global_step)
def _log_samples(self, samples: list, step: int):
"""Log generated samples to console and W&B."""
from axolotl.utils.generation.sft import format_generation_for_logging
for i, sample in enumerate(samples):
console_text, wandb_text = format_generation_for_logging(sample, i, step)
LOG.info(console_text)
try:
import wandb
if wandb.run is not None:
wandb.log(
{
f"samples/sample_{i + 1}": wandb.Html(
f"<pre>{wandb_text}</pre>"
)
},
step=step,
)
except Exception:  # wandb not installed or logging failed; never interrupt training
pass

View File

@@ -54,15 +54,19 @@ class FileLockLoader:
def cleanup(self):
"""Clean up ready flag when last process is done."""
with FileLock(str(self.lock_file_path)):
counter_content = self.counter_path.read_text().strip()
count = int(counter_content) if counter_content else 0
count -= 1
try:
with FileLock(str(self.lock_file_path)):
counter_content = self.counter_path.read_text().strip()
count = int(counter_content) if counter_content else 0
count -= 1
if count <= 0:
# Last process cleans everything up
self.ready_flag_path.unlink(missing_ok=True)
self.counter_path.unlink(missing_ok=True)
else:
# Still have active processes
self.counter_path.write_text(str(count))
if count <= 0:
# Last process cleans everything up
self.ready_flag_path.unlink(missing_ok=True)
self.counter_path.unlink(missing_ok=True)
else:
# Still have active processes
self.counter_path.write_text(str(count))
except FileNotFoundError:
# Lock file might have already been deleted by another process
pass

View File

@@ -246,6 +246,10 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
dataset = merge_datasets(split_datasets, cfg)
if not cfg.skip_prepare_dataset:
# Deduplicate before saving so the saved dataset is already de-duplicated
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
# Save preprocessed dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path

View File

@@ -351,6 +351,10 @@ def _load_raw_datasets(
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Deduplicate before saving so the saved dataset is already de-duplicated
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
@@ -438,25 +442,8 @@ def _handle_train_dataset_split(
)
return train_dataset, eval_dataset
# No validation split - apply deduplication if needed and return as train dataset
if cfg.dataset_exact_deduplication:
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication:
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
eval_dataset = dataset
return None, eval_dataset
# No validation split - deduplication already applied during preprocessing
return dataset, None
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
@@ -515,6 +502,7 @@ def _load_and_prepare_datasets(
if split == "train":
train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)
else:
train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
# Deduplication already applied during preprocessing
train_dataset, eval_dataset = None, dataset
return train_dataset, eval_dataset, prompters

View File

@@ -520,7 +520,8 @@ def generate_dataset_hash_from_config(
"""
config_str = (
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
f"{cfg.dataset_exact_deduplication or False}|"
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
f"|{tokenizer_name}"
)

View File

@@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import filter_sequences_by_length
LOG = get_logger(__name__)
@@ -148,22 +148,33 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def keep_min_len(sample, min_sequence_len=2):
"""
Truncate samples whose sequence length is too long (> sequence_len)
or drop those too short (< min_sequence_len).
Batched filter function that keeps only samples with sequence length >= min_sequence_len.
Returns a list of booleans indicating which samples to keep.
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
results.append(len(seq) >= min_sequence_len)
return results
def truncate_long_seq(sample, sequence_len=2048):
"""
Truncate samples whose sequence length is too long (> sequence_len).
Modifies the sample in-place and returns the modified sample.
"""
input_ids = sample["input_ids"]
# Batched (input_ids is a list of lists)
for i, seq in enumerate(input_ids):
length = len(seq)
if length < min_sequence_len:
results.append(False)
elif length > sequence_len:
if length > sequence_len:
sample["input_ids"][i] = seq[:sequence_len]
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
@@ -171,10 +182,133 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
sample["labels"][i] = sample["labels"][i][:sequence_len]
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
results.append(True)
else:
results.append(True)
return results
return sample
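# Toy batched example of the two helpers above (values are made up; datasets
# calls them with batched=True, so input_ids is a list of lists):
_batch = {
    "input_ids": [[1, 2, 3, 4, 5, 6], [7, 8]],
    "labels": [[1, 2, 3, 4, 5, 6], [-100, 8]],
}
print(keep_min_len(_batch, min_sequence_len=3))  # [True, False]
_batch = truncate_long_seq(_batch, sequence_len=4)
print(_batch["input_ids"])  # [[1, 2, 3, 4], [7, 8]]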
def _should_skip_processing(dataset: Dataset) -> bool:
"""Check if dataset should skip long sequence handling."""
if (
hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return True
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
return True
return False
def _log_dataset_stats(dataset: Dataset) -> None:
"""Log min/max sequence lengths for debugging."""
with contextlib.suppress(AttributeError, ValueError):
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
LOG.info(f"min_input_len: {np.min(ds_lengths)}")
LOG.info(f"max_input_len: {np.max(ds_lengths)}")
def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:
"""Build kwargs for dataset filter/map operations."""
kwargs = {}
if not isinstance(dataset, IterableDataset):
kwargs["num_proc"] = cfg.dataset_num_proc
kwargs["load_from_cache_file"] = not cfg.is_preprocess
return kwargs
def _filter_short_sequences(
dataset: Dataset, min_len: int, filter_kwargs: dict
) -> tuple[Dataset, int]:
"""Filter out sequences shorter than min_len. Returns (dataset, num_dropped)."""
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
desc_kwargs = {}
if filter_kwargs:
desc_kwargs["desc"] = f"Filtering Short Sequences (<{min_len})"
dataset = dataset.filter(
functools.partial(keep_min_len, min_sequence_len=min_len),
batched=True,
**filter_kwargs,
**desc_kwargs,
)
dropped = 0
if prior_len:
dropped = prior_len - len(dataset)
if dropped > 0:
LOG.info(f"Dropped {dropped} short sequences (<{min_len} tokens)")
return dataset, dropped
def _truncate_long_sequences(
dataset: Dataset, max_len: int, map_kwargs: dict
) -> Dataset:
"""Truncate sequences longer than max_len."""
desc_kwargs = {}
if map_kwargs:
desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
dataset = dataset.map(
functools.partial(truncate_long_seq, sequence_len=max_len),
batched=True,
**map_kwargs,
**desc_kwargs,
)
LOG.info(f"Truncated long sequences to max length {max_len}")
return dataset
def _drop_outside_range(
dataset: Dataset,
max_len: int,
min_len: int,
raise_on_long: bool,
filter_kwargs: dict,
) -> tuple[Dataset, int]:
"""Drop sequences outside valid length range [min_len, max_len].
Returns (dataset, num_dropped)."""
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
desc_kwargs = {}
if filter_kwargs:
action = (
"Checking Sequence Lengths"
if raise_on_long
else "Dropping Invalid Sequences"
)
desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
dataset = dataset.filter(
functools.partial(
filter_sequences_by_length,
sequence_len=max_len,
min_sequence_len=min_len,
raise_on_drop=raise_on_long,
),
batched=True,
**filter_kwargs,
**desc_kwargs,
)
dropped = 0
if not raise_on_long and prior_len:
dropped = prior_len - len(dataset)
if dropped > 0:
LOG.info(
f"Dropped {dropped} sequences outside valid range "
f"([{min_len}, {max_len}])"
)
return dataset, dropped
def handle_long_seq_in_dataset(
@@ -193,80 +327,25 @@ def handle_long_seq_in_dataset(
'truncate' truncates them down to sequence_len
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
"""
if (
hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
# Early returns for special cases
if _should_skip_processing(dataset):
return dataset
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
drop_long = functools.partial(
drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
raise_on_drop=excess_length_strategy == "raise",
)
_log_dataset_stats(dataset)
with contextlib.suppress(AttributeError):
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
action = (
"Checking Sequence Lengths"
if excess_length_strategy == "raise"
else "Dropping Long Sequences"
)
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
# Setup kwargs
filter_kwargs = _build_filter_kwargs(dataset, cfg)
# Handle sequences based on strategy
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs)
dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs)
else:
process_fn = drop_long
dataset = dataset.filter(
process_fn,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
action = (
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
raise_on_long = excess_length_strategy == "raise"
dataset, _ = _drop_outside_range(
dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs
)
return dataset

View File

@@ -0,0 +1,5 @@
"""Generation utilities for monitoring during training."""
from .sft import format_generation_for_logging, generate_samples
__all__ = ["generate_samples", "format_generation_for_logging"]

View File

@@ -0,0 +1,174 @@
"""Sample generation utilities for SFT/Pretrain training."""
from typing import Any, List, Optional
import torch
from accelerate.utils import extract_model_from_parallel
from colorama import Fore, Style
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def generate_samples(
model: torch.nn.Module,
tokenizer: Any,
dataloader: Any,
num_generation_samples: int = 3,
max_new_tokens: int = 50,
temperature: float = 0.7,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
do_sample: bool = True,
prompt_ratio: float = 0.5,
) -> List[dict]:
"""
Generate samples from the model during training for monitoring.
Args:
model: The model to generate from
tokenizer: The tokenizer to use for encoding/decoding
dataloader: Dataloader to sample prompts from
num_generation_samples: Number of samples to generate
max_new_tokens: Maximum new tokens to generate
temperature: Sampling temperature (0.0 = greedy)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
do_sample: Whether to use sampling vs greedy decoding
prompt_ratio: Ratio of sequence to use as prompt (0.0-1.0)
Returns:
List of dicts with 'prompt', 'generated', and 'full_text' keys
"""
unwrapped_model = extract_model_from_parallel(model)
training = unwrapped_model.training
unwrapped_model.eval()
device = next(unwrapped_model.parameters()).device
generations = []
try:
with torch.no_grad():
samples_collected = 0
for batch in dataloader:
if samples_collected >= num_generation_samples:
break
input_ids = batch["input_ids"].to(device)
attention_mask = batch.get("attention_mask")
if attention_mask is not None:
attention_mask = attention_mask.to(device)
batch_size = input_ids.shape[0]
indices = torch.randperm(batch_size)[
: num_generation_samples - samples_collected
]
for idx in indices:
if samples_collected >= num_generation_samples:
break
sequence = input_ids[idx]
if attention_mask is not None:
seq_len = attention_mask[idx].sum().item()
else:
seq_len = sequence.shape[0]
if seq_len < 5:
continue
prompt_len = max(1, int(seq_len * prompt_ratio))
prompt_ids = sequence[:prompt_len].unsqueeze(0)
try:
generation_config = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"pad_token_id": tokenizer.pad_token_id
if tokenizer.pad_token_id is not None
else tokenizer.eos_token_id,
}
if do_sample:
generation_config["temperature"] = temperature
if top_p is not None:
generation_config["top_p"] = top_p
if top_k is not None:
generation_config["top_k"] = top_k
generated_ids = unwrapped_model.generate(
prompt_ids, **generation_config
)
prompt_text = tokenizer.decode(
prompt_ids[0], skip_special_tokens=True
)
generated_text = tokenizer.decode(
generated_ids[0][prompt_len:], skip_special_tokens=True
)
full_text = tokenizer.decode(
generated_ids[0], skip_special_tokens=True
)
generations.append(
{
"prompt": prompt_text,
"generated": generated_text,
"full_text": full_text,
}
)
samples_collected += 1
except Exception as e:
LOG.warning(f"Failed to generate sample: {e}", exc_info=True)
continue
except Exception as e:
LOG.warning(f"Error during sample generation: {e}", exc_info=True)
if training:
unwrapped_model.train()
else:
unwrapped_model.eval()
return generations
def format_generation_for_logging(
sample: dict, sample_idx: int, step: int
) -> tuple[str, str]:
"""
Format a generation sample for pretty logging.
Args:
sample: Dict with 'prompt', 'generated', and 'full_text' keys
sample_idx: Index of the sample
step: Current training step
Returns:
Tuple of (console_text, wandb_text)
"""
console_text = (
f"\n{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
f"{Style.BRIGHT}{Fore.GREEN}Sample {sample_idx + 1} (Step {step}){Style.RESET_ALL}\n"
f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
f"{Style.BRIGHT}{Fore.YELLOW}[PROMPT]{Style.RESET_ALL}\n{sample['prompt']}\n\n"
f"{Style.BRIGHT}{Fore.MAGENTA}[GENERATED]{Style.RESET_ALL}\n{sample['generated']}\n"
f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
)
wandb_text = (
f"\n{'=' * 80}\n"
f"Sample {sample_idx + 1} (Step {step})\n"
f"{'=' * 80}\n"
f"[PROMPT]\n{sample['prompt']}\n\n"
f"[GENERATED]\n{sample['generated']}\n"
f"{'=' * 80}\n"
)
return console_text, wandb_text

View File

@@ -30,18 +30,8 @@ class Mistral3Processor(ProcessorMixin):
Wraps HFMistralTokenizer and adds image processing capabilities.
"""
# TODO(nano): This should be removed in transformers V5
attributes = ["tokenizer"]
tokenizer_class = "HFMistralTokenizer"
def __init__(self, tokenizer: HFMistralTokenizer):
# Don't call super().__init__ to avoid the class validation issue
self.tokenizer = tokenizer
@property
def chat_template(self) -> None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
super().__init__(tokenizer)
@property
def audio_tokenizer(self) -> None:

View File

@@ -338,18 +338,6 @@ class AxolotlInputConfig(
)
ddp_find_unused_parameters: bool | None = None
eval_table_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0"
},
)
eval_max_new_tokens: int | None = Field(
default=None,
json_schema_extra={
"description": "Total number of tokens generated for predictions sent to wandb. Default is 128"
},
)
do_causal_lm_eval: bool | None = Field(
default=None,
json_schema_extra={
@@ -446,7 +434,16 @@ class AxolotlInputConfig(
},
)
unfrozen_parameters: list[str] | None = None
unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "List of regex patterns for parameter names to keep unfrozen. "
"All other parameters will be frozen via requires_grad=False. "
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
"zeroing rather than a true freeze, so weight decay will still apply to the "
"frozen portion and optimizer states are allocated for the full parameter."
},
)
sequence_len: int = Field(
default=512,
@@ -632,6 +629,17 @@ class AxolotlInputConfig(
},
)
quantize_moe_experts: bool = Field(
default=False,
json_schema_extra={
"description": "Quantize MoE expert weights on load to reduce VRAM. "
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
"Requires CUDA (not compatible with ROCm or other backends). "
"Note: total parameter count may be reported incorrectly when enabled "
"(trainable param count is correct)."
},
)
scaling_softmax: bool | None = Field(
default=None,
json_schema_extra={
@@ -1097,6 +1105,46 @@ class AxolotlInputConfig(
"description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
},
)
generate_samples: bool | None = Field(
default=False,
json_schema_extra={
"description": "Enable sample generation during training for monitoring"
},
)
num_generation_samples: int | None = Field(
default=3,
json_schema_extra={
"description": "Number of samples to generate at each interval"
},
)
generation_max_new_tokens: int | None = Field(
default=50,
json_schema_extra={"description": "Maximum new tokens to generate per sample"},
)
generation_temperature: float | None = Field(
default=0.7,
json_schema_extra={
"description": "Temperature for sample generation (0.0 = greedy)"
},
)
generation_top_p: float | None = Field(
default=None,
json_schema_extra={"description": "Nucleus sampling parameter for generation"},
)
generation_top_k: int | None = Field(
default=None,
json_schema_extra={"description": "Top-k sampling parameter for generation"},
)
generation_prompt_ratio: float | None = Field(
default=0.5,
json_schema_extra={"description": "Ratio of input to use as prompt (0.0-1.0)"},
)
generation_do_sample: bool | None = Field(
default=True,
json_schema_extra={
"description": "Whether to use sampling (vs greedy decoding)"
},
)
@field_serializer("datasets")
def datasets_serializer(
@@ -1252,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_quantize_moe_experts(cls, data):
if data.get("quantize_moe_experts"):
if data.get("adapter") not in ("lora", "qlora"):
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
raise ValueError(
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
)
if (
data.get("capabilities")
and data["capabilities"].get("compute_capability")
and not data["capabilities"]["compute_capability"].startswith("sm_")
):
raise ValueError(
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):
@@ -1472,3 +1540,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"dataset_exact_deduplication is not available for streaming datasets. "
)
return data
@model_validator(mode="before")
@classmethod
def check_deduplication_with_skip_prepare(cls, data):
if data.get("dataset_exact_deduplication") and data.get("skip_prepare_dataset"):
raise ValueError(
"dataset_exact_deduplication=True has no effect when "
"skip_prepare_dataset=True. Deduplication runs as part of the "
"prepare pipeline, which is skipped. Either set "
"skip_prepare_dataset: false or disable "
"dataset_exact_deduplication."
)
return data

View File

@@ -17,6 +17,8 @@ class DeprecatedParameters(BaseModel):
noisy_embedding_alpha: float | None = None
dpo_beta: float | None = None
evaluation_strategy: str | None = None
eval_table_size: int | None = None
eval_max_new_tokens: int | None = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -55,6 +57,27 @@ class DeprecatedParameters(BaseModel):
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
return evaluation_strategy
@field_validator("eval_table_size")
@classmethod
def validate_eval_table_size(cls, eval_table_size):
if eval_table_size is not None:
LOG.warning(
"eval_table_size is deprecated and superseded by generate_samples config. "
"Please use generate_samples: true and num_generation_samples instead. "
"The LogPredictionCallback is replaced by the new sample generation feature."
)
return eval_table_size
@field_validator("eval_max_new_tokens")
@classmethod
def validate_eval_max_new_tokens(cls, eval_max_new_tokens):
if eval_max_new_tokens is not None:
LOG.warning(
"eval_max_new_tokens is deprecated and superseded by generate_samples config. "
"Please use generation_max_new_tokens instead."
)
return eval_max_new_tokens
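A small before/after sketch of the migration these warnings point to; the new keys are the generate_samples fields introduced earlier in this change.
# Deprecated keys (still parsed, but only emit the warnings above):
old_cfg = {"eval_table_size": 5, "eval_max_new_tokens": 128}
# Suggested replacement using the sample-generation config:
new_cfg = {
    "generate_samples": True,
    "num_generation_samples": 5,  # supersedes eval_table_size
    "generation_max_new_tokens": 128,  # supersedes eval_max_new_tokens
}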
class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""

View File

@@ -1,6 +1,6 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator, model_validator
@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
)
adapter: str | None = Field(
adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field(
default=None,
json_schema_extra={
"description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
"description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model"
},
)
lora_model_dir: str | None = Field(
@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="after")
def validate_lora_target_parameters_dropout(self):
if (
self.lora_target_parameters
and self.lora_dropout
and self.lora_dropout != 0.0
):
raise ValueError(
"lora_dropout must be 0 when lora_target_parameters is set. "
"PEFT's ParamWrapper does not support lora_dropout != 0."
)
return self
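For reference, a sketch of adapter settings the new validator accepts: lora_target_parameters is only usable with lora_dropout pinned to 0, per the PEFT ParamWrapper limitation cited in the error message. The target string is just an example.
# Accepted by validate_lora_target_parameters_dropout:
ok = {
    "adapter": "lora",
    "lora_target_parameters": ["mlp.experts.gate_up_proj"],  # example target
    "lora_dropout": 0.0,  # must stay 0 when targeting parameters directly
}
# Rejected with "lora_dropout must be 0 when lora_target_parameters is set":
bad = dict(ok, lora_dropout=0.1)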
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -205,10 +205,13 @@ def add_length(sample):
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
def filter_sequences_by_length(
sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False
):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Filter sequences outside valid length range [min_sequence_len, sequence_len].
Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len).
Works for both single-example (list[int]) and batched (list[list[int]]) inputs.
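The renamed helper keeps the old return contract; a quick sketch of both call shapes, matching the behavior the updated pretraining tests below exercise.
from axolotl.utils.trainer import filter_sequences_by_length
single = {"input_ids": list(range(16))}  # one 16-token example
batch = {"input_ids": [list(range(16)), list(range(20))]}  # batched examples
filter_sequences_by_length(single, sequence_len=32)  # True: fits within 32
filter_sequences_by_length(single, sequence_len=15)  # False: too long, would be dropped
filter_sequences_by_length(batch, sequence_len=18)  # [True, False]: per-sample keep mask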
@@ -383,10 +386,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
drop_outside_range,
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
@@ -480,7 +483,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_prcoesses,
num_processes=cfg.dataset_num_proc,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
)

View File

@@ -0,0 +1,227 @@
"""Tests for nested config option handling via CLI dot-notation."""
import click
from click.testing import CliRunner
from pydantic import BaseModel, Field
from axolotl.cli.utils.args import add_options_from_config, filter_none_kwargs
class InnerConfig(BaseModel):
"""A nested config model for testing."""
beta: float | None = Field(
default=None,
description="Beta parameter.",
)
host: str | None = Field(
default=None,
description="Server host.",
)
use_feature: bool = Field(
default=False,
description="Whether to use the feature.",
)
class OuterConfig(BaseModel):
"""A top-level config model for testing."""
learning_rate: float | None = Field(
default=None,
description="Learning rate.",
)
inner: InnerConfig | None = Field(
default=None,
description="Inner config.",
)
name: str | None = Field(
default=None,
description="Model name.",
)
class TestAddOptionsFromConfigNested:
"""Test that add_options_from_config handles nested BaseModel fields."""
def setup_method(self):
self.runner = CliRunner()
def test_nested_dot_notation_options_are_registered(self):
"""Nested model fields should create --parent.child CLI options."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(cmd, ["--inner.beta=0.5", "--inner.host=localhost"])
assert result.exit_code == 0, result.output
assert "inner__beta=0.5" in result.output
assert "inner__host=localhost" in result.output
def test_nested_bool_option(self):
"""Nested bool fields should support --parent.field/--no-parent.field."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(cmd, ["--inner.use-feature"])
assert result.exit_code == 0, result.output
assert "inner__use_feature=True" in result.output
def test_flat_and_nested_options_together(self):
"""Flat and nested options should work together."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(
cmd, ["--learning-rate=0.001", "--inner.beta=0.1", "--name=test"]
)
assert result.exit_code == 0, result.output
assert "learning_rate=0.001" in result.output
assert "inner__beta=0.1" in result.output
assert "name=test" in result.output
def test_no_nested_options_passed(self):
"""When no nested options are passed, they should not appear in kwargs."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
click.echo(f"keys={sorted(kwargs.keys())}")
result = self.runner.invoke(cmd, ["--learning-rate=0.01"])
assert result.exit_code == 0, result.output
assert "inner__" not in result.output
class TestLoadCfgNestedKwargs:
"""Test that load_cfg correctly applies nested (double-underscore) kwargs."""
@staticmethod
def _apply_nested_kwargs(cfg, kwargs):
"""Helper that mirrors the nested kwargs handling from load_cfg,
including type coercion for string CLI values."""
from axolotl.cli.config import _coerce_value
nested_kwargs: dict = {}
flat_kwargs: dict = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
cfg_keys = cfg.keys()
for key, value in flat_kwargs.items():
if key in cfg_keys:
cfg[key] = _coerce_value(value, cfg.get(key))
for parent, children in nested_kwargs.items():
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
cfg[parent] = {}
for child_key, child_value in children.items():
existing = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing)
return cfg
def test_nested_kwargs_applied_to_cfg(self, tmp_path):
"""Double-underscore kwargs should set nested config values."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": {"beta": 0.1}, "learning_rate": 0.01})
# CLI passes strings, so simulate that
kwargs = {
"trl__beta": "0.5",
"trl__host": "192.168.1.1",
"learning_rate": "0.02",
}
cfg = self._apply_nested_kwargs(cfg, kwargs)
assert cfg["learning_rate"] == 0.02
assert isinstance(cfg["learning_rate"], float)
assert cfg["trl"]["beta"] == 0.5
assert isinstance(cfg["trl"]["beta"], float)
assert cfg["trl"]["host"] == "192.168.1.1"
def test_nested_kwargs_creates_parent_if_none(self):
"""If the parent key is None, nested kwargs should create the dict."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": None, "learning_rate": 0.01})
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
# No existing value, YAML-style inference: "0.5" -> 0.5
assert cfg["trl"]["beta"] == 0.5
assert isinstance(cfg["trl"]["beta"], float)
def test_nested_kwargs_overwrites_string_parent(self):
"""If the parent key is a string, it should be replaced with a dict."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": "some_string", "learning_rate": 0.01})
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
assert cfg["trl"]["beta"] == 0.5
class TestCoerceValue:
"""Test YAML-style type coercion for CLI string values."""
def test_coerce_with_existing_float(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("0.5", 0.1) == 0.5
assert isinstance(_coerce_value("0.5", 0.1), float)
def test_coerce_with_existing_int(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("42", 10) == 42
assert isinstance(_coerce_value("42", 10), int)
def test_coerce_with_existing_bool(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("true", False) is True
assert _coerce_value("false", True) is False
assert _coerce_value("1", False) is True
assert _coerce_value("0", True) is False
def test_coerce_yaml_inference_no_existing(self):
"""Without an existing value, use YAML-style inference."""
from axolotl.cli.config import _coerce_value
assert _coerce_value("true", None) is True
assert _coerce_value("false", None) is False
assert _coerce_value("42", None) == 42
assert isinstance(_coerce_value("42", None), int)
assert _coerce_value("3.14", None) == 3.14
assert isinstance(_coerce_value("3.14", None), float)
assert _coerce_value("null", None) is None
assert _coerce_value("hello", None) == "hello"
def test_coerce_non_string_passthrough(self):
"""Non-string values should pass through unchanged."""
from axolotl.cli.config import _coerce_value
assert _coerce_value(0.5, 0.1) == 0.5
assert _coerce_value(True, False) is True
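To make the coercion contract pinned above concrete, a standalone stand-in with the same behavior these tests check; this is an assumption-labeled sketch, not the actual _coerce_value in axolotl.cli.config.
import yaml
def coerce_value_sketch(value, existing=None):
    """Illustrative stand-in for the tested behavior (not the real implementation)."""
    if not isinstance(value, str):
        return value  # non-strings pass through unchanged
    if isinstance(existing, bool):  # check bool before int (bool subclasses int)
        return value.strip().lower() in ("true", "1", "yes")
    if isinstance(existing, int):
        return int(value)
    if isinstance(existing, float):
        return float(value)
    return yaml.safe_load(value)  # YAML-style inference: "42" -> 42, "null" -> None
assert coerce_value_sketch("0.5", 0.1) == 0.5
assert coerce_value_sketch("0", True) is False
assert coerce_value_sketch("hello") == "hello"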

View File

@@ -300,7 +300,6 @@ class TestHFRLTrainerBuilder:
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)

View File

@@ -186,6 +186,7 @@ class TestFSDP1:
verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{

View File

@@ -365,6 +365,7 @@ class TestFSDP2:
verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
@@ -422,6 +423,7 @@ class TestFSDP2:
verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0
def test_dpo_lora(self, temp_dir):
cfg = DictDefault(

View File

@@ -0,0 +1,323 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel_scattermoe
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
"""
from unittest.mock import patch
import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
# ============================================================================
class TestKernelsArgsValidator:
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
These tests call the validator classmethod directly on raw dicts,
since lora_mlp_kernel / mlp_kernel are not declared model fields.
"""
def test_disables_lora_mlp_kernel_when_scattermoe(self):
"""lora_mlp_kernel=True gets set to False when use_scattermoe=True."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": True,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
def test_mlp_kernel_disabled_without_lora(self):
"""Even without lora_mlp_kernel, mlp_kernel should be disabled."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["mlp_kernel"] is False
# lora_mlp_kernel was not in data, should not be added
assert "lora_mlp_kernel" not in result
def test_lora_mlp_kernel_false_unchanged(self):
"""lora_mlp_kernel=False should stay False (no warning, no change)."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": True,
"lora_mlp_kernel": False,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is False
def test_no_change_when_scattermoe_disabled(self):
"""When use_scattermoe is not True, nothing should be changed."""
from axolotl.integrations.kernels.args import KernelsArgs
data = {
"use_kernels": True,
"use_scattermoe": False,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is True
class TestParallelExpertsScaling:
"""Test that scaling=0.0 is preserved and not overridden to 1.0."""
def test_scaling_zero_preserved(self):
"""scaling=0.0 should be passed as 0.0, not replaced with 1.0."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
ParallelExperts,
)
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
pe.set_lora(
lora_A=torch.randn(4, 4),
lora_B=torch.randn(4, 4),
scaling=0.0,
)
assert pe._lora_scaling == 0.0
# Patch parallel_linear_lora to capture the scaling arg
with patch(
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
) as mock_pll:
mock_pll.return_value = torch.randn(4, 4)
# Create dummy routing tensors
pe.forward(
inputs=torch.randn(2, 4),
k=1,
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
expert_offsets=torch.tensor([2, 4]),
)
# Check that scaling=0.0 was passed, not 1.0
call_kwargs = mock_pll.call_args
assert (
call_kwargs.kwargs.get("scaling") == 0.0
or call_kwargs[1].get("scaling") == 0.0
), f"Expected scaling=0.0 but got {call_kwargs}"
def test_scaling_none_defaults_to_one(self):
"""scaling=None (no LoRA attached) should default to 1.0."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
ParallelExperts,
)
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
# No set_lora called, so _lora_scaling is None
with patch(
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
) as mock_pll:
mock_pll.return_value = torch.randn(4, 4)
pe.forward(
inputs=torch.randn(2, 4),
k=1,
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
expert_offsets=torch.tensor([2, 4]),
)
call_kwargs = mock_pll.call_args
scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
"scaling"
)
assert scaling_val == 1.0, (
f"Expected scaling=1.0 for None but got {scaling_val}"
)
def test_scaling_positive_preserved(self):
"""Normal positive scaling should be preserved."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
ParallelExperts,
)
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
pe.set_lora(
lora_A=torch.randn(4, 4),
lora_B=torch.randn(4, 4),
scaling=0.5,
)
with patch(
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
) as mock_pll:
mock_pll.return_value = torch.randn(4, 4)
pe.forward(
inputs=torch.randn(2, 4),
k=1,
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
expert_offsets=torch.tensor([2, 4]),
)
call_kwargs = mock_pll.call_args
scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
"scaling"
)
assert scaling_val == 0.5
# ============================================================================
# 4. single2scatter: non-aligned K/N dimensions (GPU only)
# ============================================================================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSingle2ScatterBounds:
"""Test single2scatter with non-aligned dimensions."""
def test_non_aligned_k(self):
"""K not a multiple of BLOCK_K should produce correct results."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
single2scatter,
)
E, K, N = 2, 100, 128 # K=100 not a multiple of 128
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
Y = single2scatter(X, W, expert_idxs)
assert Y.shape == (2, N)
# Verify against manual computation
Y_ref_0 = X[0] @ W[0]
Y_ref_1 = X[0] @ W[1]
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
def test_non_aligned_n(self):
"""N not a multiple of BLOCK_N should produce correct results."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
single2scatter,
)
E, K, N = 2, 128, 100 # N=100 not a multiple of 128
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
Y = single2scatter(X, W, expert_idxs)
assert Y.shape == (2, N)
Y_ref_0 = X[0] @ W[0]
Y_ref_1 = X[0] @ W[1]
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
def test_non_aligned_both(self):
"""Both K and N not aligned should produce correct results."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
single2scatter,
)
E, K, N = 2, 100, 100 # Neither aligned to 128
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
Y = single2scatter(X, W, expert_idxs)
assert Y.shape == (2, N)
Y_ref_0 = X[0] @ W[0]
Y_ref_1 = X[0] @ W[1]
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
# ============================================================================
# 5. group_compileable: coeff=None accepted
# ============================================================================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGroupCoeffNone:
"""Test that group() works with coeff=None."""
def test_group_with_none_coeff(self):
"""group() should accept coeff=None without errors."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group
M, K = 4, 32
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
# This should not raise a TypeError
Y = group(A, sorted_expert_idxs, coeff=None, fan_out=1)
assert Y.shape == (M, K)
def test_group_with_coeff(self):
"""group() should also work with actual coeff values."""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group
M, K = 4, 32
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
coeff = torch.ones(M, device="cuda", dtype=torch.float32) * 0.5
Y = group(A, sorted_expert_idxs, coeff=coeff, fan_out=1)
assert Y.shape == (M, K)
# ============================================================================
# 6. Layer return value contracts
# ============================================================================
class TestLayerReturnValues:
"""Test that layer forward methods return the correct types."""
def test_hf_scatter_moe_returns_single_tensor(self):
"""HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple."""
pytest.importorskip("triton")
# Verify the forward method signature and return annotation
import inspect
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
)
sig = inspect.signature(HFScatterMoEGatedMLP.forward)
# The unbound forward method takes (self, layer_input)
params = list(sig.parameters.keys())
assert "self" in params
assert "layer_input" in params
def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):
"""ScatterMoEGatedMLP.forward docstring should not mention router logits as return."""
pytest.importorskip("triton")
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
ScatterMoEGatedMLP,
)
docstring = ScatterMoEGatedMLP.forward.__doc__
assert docstring is not None
# The docstring should mention output tensor but NOT router logits
assert "Output tensor" in docstring or "output tensor" in docstring.lower()
assert "Router logits" not in docstring, (
"Docstring should not mention 'Router logits' in Returns section"
)

View File

@@ -7,7 +7,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import filter_sequences_by_length
from tests.hf_offline_utils import enable_hf_offline
@@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase):
# -- single sequence --
# This should work
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
drop_long_seq(data, 32, raise_on_drop=True)
filter_sequences_by_length(data, 32, raise_on_drop=True)
# This should return True, since data fits
dropped = drop_long_seq(data, 32)
dropped = filter_sequences_by_length(data, 32)
self.assertTrue(dropped)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
self.assertRaises(
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should return False, since data doesn't fit
dropped = drop_long_seq(data, 15)
dropped = filter_sequences_by_length(data, 15)
self.assertFalse(dropped)
# -- batch sequence --
@@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase):
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
]
}
drop_long_seq(data, 32, raise_on_drop=True)
filter_sequences_by_length(data, 32, raise_on_drop=True)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
self.assertRaises(
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should keep the first but drop the second entry
dropped = drop_long_seq(data, 15)
dropped = filter_sequences_by_length(data, 15)
self.assertEqual(dropped, [True, False])

View File

@@ -0,0 +1,135 @@
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
from unittest.mock import MagicMock, patch
from transformers import PreTrainedTokenizerBase
from axolotl.utils.dict import DictDefault
class TestRevisionParameter:
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""
@patch("axolotl.loaders.tokenizer.load_model_config")
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch(
"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
)
def test_load_tokenizer_passes_revision(
self, _mock_patches, mock_auto_tokenizer, _mock_load_config
):
mock_tokenizer = MagicMock()
mock_tokenizer.__class__.__name__ = "MockTokenizer"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
cfg = DictDefault(
{
"tokenizer_config": "some-model",
"revision_of_model": "abc123",
}
)
from axolotl.loaders.tokenizer import load_tokenizer
load_tokenizer(cfg)
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "abc123"
@patch("axolotl.loaders.tokenizer.load_model_config")
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch(
"axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
)
def test_load_tokenizer_omits_revision_when_unset(
self, _mock_patches, mock_auto_tokenizer, _mock_load_config
):
mock_tokenizer = MagicMock()
mock_tokenizer.__class__.__name__ = "MockTokenizer"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
cfg = DictDefault(
{
"tokenizer_config": "some-model",
}
)
from axolotl.loaders.tokenizer import load_tokenizer
load_tokenizer(cfg)
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
@patch("axolotl.loaders.tokenizer.barrier")
def test_modify_tokenizer_files_passes_revision(
self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
):
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
from axolotl.loaders.tokenizer import modify_tokenizer_files
modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "abc123"
@patch("axolotl.loaders.tokenizer.AutoTokenizer")
@patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
@patch("axolotl.loaders.tokenizer.barrier")
def test_modify_tokenizer_files_defaults_revision_to_main(
self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
):
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
from axolotl.loaders.tokenizer import modify_tokenizer_files
modify_tokenizer_files("some-model", {}, output_dir=temp_dir)
call_kwargs = mock_auto_tokenizer.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "main"
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_passes_revision(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"revision_of_model": "abc123",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert call_kwargs.kwargs.get("revision") == "abc123"
@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
mock_processor = MagicMock()
mock_processor.size = {}
mock_auto_processor.from_pretrained.return_value = mock_processor
cfg = DictDefault(
{
"processor_config": "some-model",
"trust_remote_code": False,
}
)
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
from axolotl.loaders.processor import load_processor
load_processor(cfg, tokenizer)
call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs

View File

@@ -0,0 +1,210 @@
"""Tests to verify that deduplication runs before dataset saving during preprocessing.
This addresses GitHub issue #2719: Save De-duplicated Set During Pre-processing.
"""
from unittest.mock import MagicMock, patch
from datasets import Dataset
from axolotl.utils.dict import DictDefault
class TestSFTSaveDeduplicatedBeforeSave:
"""Verify that in SFT data loading, deduplication occurs before saving."""
@patch("axolotl.utils.data.sft.save_preprocessed_dataset")
@patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
@patch("axolotl.utils.data.sft.deduplicate_and_log_datasets")
@patch("axolotl.utils.data.sft.merge_datasets")
@patch("axolotl.utils.data.sft._load_and_process_single_dataset")
@patch("axolotl.utils.data.sft.datasets_with_name_generator")
def test_dedup_called_before_save_sft(
self,
mock_datasets_gen,
mock_load_single,
mock_merge,
mock_dedup,
mock_gen_hash,
mock_save,
):
"""Deduplication should be called before save_preprocessed_dataset in SFT."""
from axolotl.utils.data.sft import _load_raw_datasets
# Set up mock data
dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
deduped_dataset = Dataset.from_dict({"text": ["a", "b"], "label": [1, 2]})
mock_datasets_gen.return_value = [
DictDefault({"path": "test", "type": "alpaca"})
]
mock_load_single.return_value = (dataset, None)
mock_merge.return_value = dataset
mock_dedup.return_value = (deduped_dataset, None)
mock_gen_hash.return_value = "testhash"
cfg = DictDefault(
{
"skip_prepare_dataset": False,
"dataset_exact_deduplication": True,
"sequence_len": 1024,
"eval_sequence_len": None,
"sample_packing": False,
"is_preprocess": False,
"seed": 42,
"datasets": [{"path": "test", "type": "alpaca"}],
}
)
tokenizer = MagicMock()
tokenizer.name_or_path = "test-tokenizer"
# Track call order
call_order = []
mock_dedup.side_effect = lambda **kwargs: (
call_order.append("dedup") or (deduped_dataset, None)
)
mock_save.side_effect = lambda *args, **kwargs: call_order.append("save")
_load_raw_datasets(
cfg=cfg,
datasets_configs=cfg.datasets,
tokenizer=tokenizer,
split="train",
)
# Verify dedup was called
assert "dedup" in call_order, "Deduplication should have been called"
# Verify save was called
assert "save" in call_order, "Save should have been called"
# Verify dedup happened before save
assert call_order.index("dedup") < call_order.index("save"), (
"Deduplication must occur before saving the dataset"
)
@patch("axolotl.utils.data.sft.save_preprocessed_dataset")
@patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
@patch("axolotl.utils.data.sft.merge_datasets")
@patch("axolotl.utils.data.sft._load_and_process_single_dataset")
@patch("axolotl.utils.data.sft.datasets_with_name_generator")
def test_no_dedup_when_disabled_sft(
self,
mock_datasets_gen,
mock_load_single,
mock_merge,
mock_gen_hash,
mock_save,
):
"""Deduplication should not be called when dataset_exact_deduplication is False."""
from axolotl.utils.data.sft import _load_raw_datasets
dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
mock_datasets_gen.return_value = [
DictDefault({"path": "test", "type": "alpaca"})
]
mock_load_single.return_value = (dataset, None)
mock_merge.return_value = dataset
mock_gen_hash.return_value = "testhash"
cfg = DictDefault(
{
"skip_prepare_dataset": False,
"dataset_exact_deduplication": False,
"sequence_len": 1024,
"eval_sequence_len": None,
"sample_packing": False,
"is_preprocess": False,
"seed": 42,
"datasets": [{"path": "test", "type": "alpaca"}],
}
)
tokenizer = MagicMock()
tokenizer.name_or_path = "test-tokenizer"
with patch("axolotl.utils.data.sft.deduplicate_and_log_datasets") as mock_dedup:
_load_raw_datasets(
cfg=cfg,
datasets_configs=cfg.datasets,
tokenizer=tokenizer,
split="train",
)
mock_dedup.assert_not_called()
class TestRLSaveDeduplicatedBeforeSave:
"""Verify that in RL data loading, deduplication occurs before saving."""
@patch.object(Dataset, "filter", lambda self, *args, **kwargs: self)
@patch("axolotl.utils.data.rl.save_preprocessed_dataset")
@patch("axolotl.utils.data.rl.generate_dataset_hash_from_config")
@patch("axolotl.utils.data.rl.deduplicate_and_log_datasets")
@patch("axolotl.utils.data.rl.merge_datasets")
@patch("axolotl.utils.data.rl.load_dataset_with_config")
@patch("axolotl.utils.data.rl.datasets_with_name_generator")
@patch("axolotl.utils.data.rl.load_tokenizer")
def test_dedup_called_before_save_rl(
self,
mock_load_tokenizer,
mock_datasets_gen,
mock_load_dataset,
mock_merge,
mock_dedup,
mock_gen_hash,
mock_save,
):
"""Deduplication should be called before save_preprocessed_dataset in RL."""
from axolotl.utils.data.rl import _load_split
dataset = Dataset.from_dict(
{
"prompt": ["hi", "bye", "hi"],
"chosen": ["a", "b", "a"],
"rejected": ["c", "d", "c"],
}
)
deduped_dataset = Dataset.from_dict(
{
"prompt": ["hi", "bye"],
"chosen": ["a", "b"],
"rejected": ["c", "d"],
}
)
mock_datasets_gen.return_value = [DictDefault({"path": "test", "type": None})]
mock_load_dataset.return_value = dataset
mock_merge.return_value = dataset
mock_dedup.return_value = (deduped_dataset, None)
mock_gen_hash.return_value = "testhash"
tokenizer = MagicMock()
tokenizer.name_or_path = "test-tokenizer"
mock_load_tokenizer.return_value = tokenizer
cfg = DictDefault(
{
"skip_prepare_dataset": False,
"dataset_exact_deduplication": True,
"sequence_len": 1024,
"rl": "dpo",
"datasets": [{"path": "test", "type": None}],
"hf_use_auth_token": False,
"dataset_num_proc": 1,
"is_preprocess": False,
}
)
call_order = []
mock_dedup.side_effect = lambda **kwargs: (
call_order.append("dedup") or (deduped_dataset, None)
)
mock_save.side_effect = lambda *args, **kwargs: call_order.append("save")
_load_split(cfg, split="train")
assert "dedup" in call_order, "Deduplication should have been called"
assert "save" in call_order, "Save should have been called"
assert call_order.index("dedup") < call_order.index("save"), (
"Deduplication must occur before saving the dataset"
)

View File

@@ -116,6 +116,7 @@ class TestTokenizers:
tokenizer.decode([128041, 128042]) == "RANDOM_OVERRIDE_1RANDOM_OVERRIDE_2"
)
@pytest.mark.skip("FIXME slow test sdist py3.11 + torch2.8.0")
@enable_hf_offline
def test_added_tokens_overrides_gemma3(self, temp_dir):
cfg = DictDefault(

View File

@@ -0,0 +1,545 @@
"""
Unit tests for data utility functions
"""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
class TestHandleLongSeqInDataset(unittest.TestCase):
"""
Test class for handle_long_seq_in_dataset function
"""
def test_drop_strategy_removes_long_sequences(self):
"""Test that 'drop' strategy removes sequences longer than sequence_len"""
# Create dataset with mixed length sequences
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3], # length 3 - keep
[1, 2, 3, 4, 5], # length 5 - keep
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - drop
[1, 2], # length 2 - keep
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have dropped the sequence with length 11
self.assertEqual(len(result), 3)
self.assertEqual(len(result[0]["input_ids"]), 3)
self.assertEqual(len(result[1]["input_ids"]), 5)
self.assertEqual(len(result[2]["input_ids"]), 2)
def test_drop_strategy_is_default(self):
"""Test that 'drop' is the default strategy when not specified"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should drop
]
}
)
cfg = DictDefault(
{
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have dropped the long sequence
self.assertEqual(len(result), 1)
def test_truncate_strategy_truncates_long_sequences(self):
"""Test that 'truncate' strategy truncates sequences to sequence_len"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3], # length 3 - keep as is
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
], # length 12 - truncate to 10
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have 2 samples
self.assertEqual(len(result), 2)
# First sample unchanged
self.assertEqual(len(result[0]["input_ids"]), 3)
# Second sample truncated to 10
self.assertEqual(len(result[1]["input_ids"]), 10)
self.assertEqual(result[1]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_truncate_strategy_truncates_all_auxiliary_fields(self):
"""Test that truncation applies to all auxiliary fields consistently"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
"attention_mask": [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
"labels": [
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
"position_ids": [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
],
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# All fields should be truncated to 10
self.assertEqual(len(result[0]["input_ids"]), 10)
self.assertEqual(len(result[0]["attention_mask"]), 10)
self.assertEqual(len(result[0]["labels"]), 10)
self.assertEqual(len(result[0]["position_ids"]), 10)
# Verify content is correct
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
self.assertEqual(result[0]["attention_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
self.assertEqual(result[0]["position_ids"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
def test_raise_strategy_raises_on_long_sequences(self):
"""Test that 'raise' strategy raises ValueError when encountering long sequences"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should raise
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "raise",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
with self.assertRaises(ValueError):
handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
def test_min_sequence_len_filters_short_sequences(self):
"""Test that sequences shorter than min_sample_len are filtered out"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - drop (< min_sample_len=3)
[1, 2], # length 2 - drop
[1, 2, 3], # length 3 - keep
[1, 2, 3, 4, 5], # length 5 - keep
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 3,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should only keep sequences with length >= 3
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]["input_ids"]), 3)
self.assertEqual(len(result[1]["input_ids"]), 5)
def test_dataset_without_input_ids_column(self):
"""Test that datasets without 'input_ids' column are returned unchanged"""
dataset = Dataset.from_dict(
{
"chosen": [1, 2, 3],
"rejected": [4, 5, 6],
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Dataset should be unchanged
self.assertEqual(len(result), len(dataset))
self.assertListEqual(list(result.column_names), ["chosen", "rejected"])
def test_truncate_filters_short_before_truncating(self):
"""Test that truncate strategy filters short sequences before truncating long ones
This is important for efficiency - we should not waste time truncating
sequences that will be filtered out anyway.
"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - filter out first
[1, 2, 3], # length 3 - keep, no truncation needed
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
], # length 12 - keep and truncate
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should have filtered out the first (short) sequence
self.assertEqual(len(result), 2)
# Second sample unchanged
self.assertEqual(len(result[0]["input_ids"]), 3)
# Third sample truncated to 10
self.assertEqual(len(result[1]["input_ids"]), 10)
def test_case_insensitive_strategy(self):
"""Test that excess_length_strategy is case-insensitive"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "TRUNCATE", # uppercase
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should still truncate
self.assertEqual(len(result[0]["input_ids"]), 10)
def test_raise_strategy_silently_drops_short_sequences(self):
"""Test that 'raise' strategy drops short sequences without raising"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - too short, should be dropped silently
[1, 2, 3, 4, 5], # length 5 - keep
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "raise",
"min_sample_len": 3,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
# Should NOT raise, just silently drop the short sequence
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]["input_ids"]), 5)
def test_drop_boundary_sequence_equal_to_sequence_len(self):
"""Test that drop strategy keeps sequences with length exactly equal to sequence_len"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 > sequence_len
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Exactly equal should be kept, one over should be dropped
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]["input_ids"]), 10)
def test_truncate_boundary_sequence_equal_to_sequence_len(self):
"""Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should be unchanged - not truncated
self.assertEqual(len(result), 1)
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_empty_dataset(self):
"""Test that an empty dataset is handled gracefully"""
dataset = Dataset.from_dict({"input_ids": []})
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 0)
def test_all_sequences_dropped_returns_empty_dataset(self):
"""Test that dropping all sequences results in an empty dataset"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # too short
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # too long
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 5,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 0)
def test_iterable_dataset_skips_processing(self):
"""Test that streaming datasets (column_names is None) are returned unchanged.
The skip check in _should_skip_processing triggers when column_names is
None, which happens with true streaming datasets loaded via
load_dataset(..., streaming=True).
"""
mock_dataset = MagicMock()
mock_dataset.column_names = None
cfg = DictDefault(
{
"excess_length_strategy": "drop",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg)
# Should be returned unchanged (same object)
self.assertIs(result, mock_dataset)
def test_truncate_with_partial_auxiliary_fields(self):
"""Test truncation when only some auxiliary fields are present"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
"labels": [
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
],
# No attention_mask or position_ids
}
)
cfg = DictDefault(
{
"excess_length_strategy": "truncate",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result[0]["input_ids"]), 10)
self.assertEqual(len(result[0]["labels"]), 10)
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
# Confirm no extra columns were introduced
self.assertListEqual(sorted(result.column_names), ["input_ids", "labels"])
def test_min_sample_len_defaults_to_two_when_not_set(self):
"""Test that min_sample_len defaults to 2 when not specified in config"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1], # length 1 - should be dropped (< default 2)
[1, 2], # length 2 - should be kept (>= default 2)
[1, 2, 3], # length 3 - should be kept
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "drop",
# min_sample_len not set
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]["input_ids"]), 2)
self.assertEqual(len(result[1]["input_ids"]), 3)
def test_invalid_strategy_falls_through_to_drop(self):
"""Test that an unrecognized strategy value falls through to drop behavior"""
dataset = Dataset.from_dict(
{
"input_ids": [
[1, 2, 3], # keep
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
], # length 11 - should be dropped
]
}
)
cfg = DictDefault(
{
"excess_length_strategy": "not_a_real_strategy",
"min_sample_len": 2,
"dataset_num_proc": None,
"is_preprocess": False,
}
)
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
# Should behave like 'drop'
self.assertEqual(len(result), 1)
self.assertEqual(len(result[0]["input_ids"]), 3)
if __name__ == "__main__":
unittest.main()
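Pulling the three strategies together, a compact sketch of the behavior these tests pin down, using the same cfg keys the tests pass; the toy dataset is illustrative.
from datasets import Dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
ds = Dataset.from_dict({"input_ids": [[1, 2, 3], list(range(1, 13))]})  # lengths 3 and 12
def run(strategy):
    cfg = DictDefault(
        {
            "excess_length_strategy": strategy,
            "min_sample_len": 2,
            "dataset_num_proc": None,
            "is_preprocess": False,
        }
    )
    return handle_long_seq_in_dataset(ds, sequence_len=10, cfg=cfg)
assert len(run("drop")) == 1  # the 12-token sample is removed
assert len(run("truncate")[1]["input_ids"]) == 10  # the 12-token sample is clipped
# run("raise") would raise ValueError because one sample exceeds sequence_len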

View File

@@ -0,0 +1,142 @@
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture()
def gpu_caps():
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
@pytest.fixture()
def env_caps():
return {"torch_version": "2.7.0"}
class TestQuantizeMoeExpertsValidation:
"""Test suite for quantize_moe_experts config validator."""
def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts without adapter should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="requires adapter"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts without load_in_4bit/8bit should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="lora",
)
| min_base_cfg
)
with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with qlora + 4bit should pass."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is True
def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with lora + 8bit should pass."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="lora",
load_in_8bit=True,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is True
def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts=false should not check adapter/quantization."""
cfg = (
DictDefault(
quantize_moe_experts=False,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is False
def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts should default to false."""
cfg = DictDefault({}) | min_base_cfg
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is False
class TestLoraTargetParametersDropout:
"""Test that lora_dropout must be 0 when lora_target_parameters is set."""
def test_rejects_nonzero_dropout(self, min_base_cfg):
"""lora_dropout > 0 with lora_target_parameters should fail."""
cfg = (
DictDefault(
adapter="lora",
lora_target_parameters=["mlp.experts.gate_up_proj"],
lora_dropout=0.1,
load_in_8bit=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="lora_dropout must be 0"):
validate_config(cfg)
def test_zero_dropout_passes(self, min_base_cfg):
"""lora_dropout=0 with lora_target_parameters should pass."""
cfg = (
DictDefault(
adapter="lora",
lora_target_parameters=["mlp.experts.gate_up_proj"],
lora_dropout=0.0,
load_in_8bit=True,
)
| min_base_cfg
)
result = validate_config(cfg)
assert result["lora_dropout"] == 0.0
class TestPeftPatchIdempotency:
"""Test that patch_peft_target_parameters_matching is idempotent."""
def test_double_call_does_not_stack_wrappers(self):
"""Calling patch twice should not double-wrap _inject_parameters."""
from peft.tuners.tuners_utils import BaseTuner
from axolotl.monkeypatch.moe_quant import (
patch_peft_target_parameters_matching,
)
original = BaseTuner._inject_parameters
try:
patch_peft_target_parameters_matching()
first_patched = BaseTuner._inject_parameters
patch_peft_target_parameters_matching()
second_patched = BaseTuner._inject_parameters
# Should be same function, not double-wrapped
assert first_patched is second_patched
finally:
BaseTuner._inject_parameters = original
patch_peft_target_parameters_matching._axolotl_patched = False

View File

@@ -0,0 +1,149 @@
"""Tests for Mistral3Processor with transformers v5 ProcessorMixin integration"""
from unittest.mock import MagicMock
import pytest
import torch
from transformers.feature_extraction_utils import BatchFeature
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
@pytest.fixture()
def mock_tokenizer():
"""Create a mock HFMistralTokenizer that passes v5 ProcessorMixin isinstance checks."""
return MagicMock(spec=HFMistralTokenizer)
@pytest.fixture()
def processor(mock_tokenizer):
return Mistral3Processor(tokenizer=mock_tokenizer)
class TestMistral3ProcessorInit:
def test_tokenizer_is_set(self, processor, mock_tokenizer):
assert processor.tokenizer is mock_tokenizer
def test_chat_template_is_none(self, processor):
assert processor.chat_template is None
def test_audio_tokenizer_is_none(self, processor):
assert processor.audio_tokenizer is None
class TestApplyChatTemplateTokenized:
"""Test apply_chat_template with tokenize=True, return_dict=True"""
@pytest.fixture()
def batched_conversations(self):
return [
[
{"role": "user", "content": "Describe this image."},
{"role": "assistant", "content": "It is red."},
],
[
{"role": "user", "content": "What is this?"},
{"role": "assistant", "content": "A cat."},
],
]
def test_returns_batch_feature_with_pixel_values(
self, processor, mock_tokenizer, batched_conversations
):
pixel_values = torch.randn(2, 3, 224, 224, dtype=torch.float64)
mock_tokenizer.apply_chat_template.return_value = {
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
"attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
"pixel_values": pixel_values,
}
result = processor.apply_chat_template(
batched_conversations, tokenize=True, return_dict=True
)
assert isinstance(result, BatchFeature)
assert "pixel_values" in result
assert "image_sizes" in result
assert result["pixel_values"].dtype == torch.float32
assert result["image_sizes"].shape == (2, 2)
assert result["image_sizes"][0].tolist() == [224, 224]
def test_returns_batch_feature_without_pixel_values(
self, processor, mock_tokenizer, batched_conversations
):
mock_tokenizer.apply_chat_template.return_value = {
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
"attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
}
result = processor.apply_chat_template(
batched_conversations, tokenize=True, return_dict=True
)
assert isinstance(result, BatchFeature)
assert "input_ids" in result
assert "image_sizes" not in result
class TestApplyChatTemplateNotTokenized:
def test_single_conversation_returns_unwrapped(self, processor, mock_tokenizer):
"""Single conversation (not batched) should return unwrapped result."""
single_conversation = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
mock_tokenizer.apply_chat_template.return_value = [
"<s>[INST]Hello[/INST]Hi</s>"
]
result = processor.apply_chat_template(
single_conversation, tokenize=False, return_dict=False
)
assert result == "<s>[INST]Hello[/INST]Hi</s>"
def test_batched_conversations_returns_list(self, processor, mock_tokenizer):
batched = [
[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
],
[
{"role": "user", "content": "Bye"},
{"role": "assistant", "content": "Bye"},
],
]
mock_tokenizer.apply_chat_template.return_value = ["text1", "text2"]
result = processor.apply_chat_template(
batched, tokenize=False, return_dict=False
)
assert result == ["text1", "text2"]
class TestCall:
def test_delegates_to_tokenizer(self, processor, mock_tokenizer):
mock_tokenizer.return_value = {
"input_ids": [1, 2, 3],
"attention_mask": [1, 1, 1],
}
result = processor("Hello world")
mock_tokenizer.assert_called_once()
assert isinstance(result, BatchFeature)
class TestReturnTensorsValidation:
def test_rejects_non_pt_return_tensors(self, processor):
conversation = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
with pytest.raises(ValueError, match=r"only supports.*return_tensors='pt'"):
processor.apply_chat_template(
conversation, tokenize=True, return_dict=True, return_tensors="np"
)