Compare commits

...

27 Commits

Author SHA1 Message Date
Wing Lian
8495c79fb1 properly handles kernels repo type 2026-04-23 14:56:16 -04:00
Wing Lian
9a0d3016df first pass at build and deploy scattermoe-lora kernel 2026-04-22 01:10:01 -04:00
thad0ctor
e562e149ce fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attentiuon) path under activation check… (#3611)
* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing

Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.

HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.

Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.

Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.

Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
  * threading.local() store with _get/_set_shared_kv_states helpers
  * _patch_decoder_layer_call(): monkeypatches
    Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
    and stash it in TLS before delegating to GradientCheckpointingLayer.
    The partial formed downstream no longer references the dict.
  * fused_forward reads TLS first, falls back to kwarg for callers that
    bypass the patched __call__ (e.g. direct attention invocation).
  * wired into patch_gemma4_fused_attn; idempotent via a sentinel.

TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.

Introduced by PR #3598 (commit b8358aa5).

* Coderabbit comment: gemma4: clear TLS unconditionally in decoder-layer patched __call__

  Overwrite the thread-local shared_kv_states store on every invocation
  (including with None) instead of only when the kwarg is present.

  The previous conditional write left stale dicts in TLS on any path that
  reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
  kwarg — e.g. generation, eval hooks, or future HF refactors that make
  the kwarg optional. fused_forward would then silently consume a prior
  step's K/V dict instead of falling back to its own kwarg path.

  Unconditional write makes the invariant in the surrounding comment
  ("TLS is overwritten on each new step's first decoder-layer call, so
  the previous step's dict is released promptly") actually hold.

  No behavior change for the training happy path, which always passes
  the kwarg. Addresses CodeRabbit review on PR #3611

* fix: swap threading.local() for module-level store so autograd worker   threads see shared_kv_states during backward recompute

Previous commits fixed memory leak on 31B but caused type error with MOE Gemma4 variants - this fixes that:

PR 3611's TLS variant only works when recompute runs on the same thread
  that set TLS during forward. PyTorch's C++ autograd engine
  (_engine_run_backward) spawns per-device worker threads to dispatch
  backward, and HF-Trainer gradient_checkpointing (stock
  torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
  unpack_hook -> recompute_fn on those worker threads. TLS set on the main
  thread during forward is invisible there, so _get_shared_kv_states()
  returns None and the consumer-layer lookup crashes with
  "'NoneType' object is not subscriptable" at
  fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).

  A plain module-level dict is visible to all threads in the process.
  Lifecycle is identical: the slot is overwritten each forward, releasing
  the previous step's dict and allowing its K/V tensors to be GC'd, so
  the original VRAM-leak fix still holds under FSDP2 AC too.

* scope gemma4 shared_kv_states side channel to checkpointed training

Update PR #3611 with gate for checkpointed training to avoid regressions across async flows.

Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments

* add gemma4 cross-thread visibility test for shared_kv_states store

Additional regression test for MoE gemma4 variants - asserts the module-level store is readable from threads other than the one that set it in response to previously observed 'NoneType' error

* fix logger

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-21 17:49:58 -04:00
NanoCode012
9de5b76336 feat: move to uv first (#3545)
* feat: move to uv first

* fix: update doc to uv first

* fix: merge dev/tests into uv pyproject

* fix: update docker docs to match current config

* fix: migrate examples to readme

* fix: add llmcompressor to conflict

* feat: rec uv sync with lockfile for dev/ci

* fix: update docker docs to clarify how to use uv images

* chore: docs

* fix: use system python, no venv

* fix: set backend cpu

* fix: only set for installing pytorch step

* fix: remove unsloth kernel and installs

* fix: remove U in tests

* fix: set backend in deps too

* chore: test

* chore: comments

* fix: attempt to lock torch

* fix: workaround torch cuda and not upgraded

* fix: forgot to push

* fix: missed source

* fix: nightly upstream loralinear config

* fix: nightly phi3 long rope not work

* fix: forgot commit

* fix: test phi3 template change

* fix: no more requirements

* fix: carry over changes from new requirements to pyproject

* chore: remove lockfile per discussion

* fix: set match-runtime

* fix: remove unneeded hf hub buildtime

* fix: duplicate cache delete on nightly

* fix: torchvision being overridden

* fix: migrate to uv images

* fix: leftover from merge

* fix: simplify base readme

* fix: update assertion message to be clearer

* chore: docs

* fix: change fallback for cicd script

* fix: match against main exactly

* fix: peft 0.19.1 change

* fix: e2e test

* fix: ci

* fix: e2e test
2026-04-21 10:16:03 -04:00
Wing Lian
323da791eb bump transformers to 5.5.4 and trl to latest 1.1.0 (#3603)
* bump transformers to 5.5.4 and trl to latest 1.1.0

* more upgrades

* update peft too

* adapt lora_merge to peft 0.19 layer config API

PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.

* allow lora_dropout on mixed attention+MoE configs under peft 0.19

PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).

Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
2026-04-15 09:27:03 -04:00
NanoCode012
6990478163 fix: rename model to adapter_model for fsdp sharded final model (#3585)
* fix: rename model to adapter_model for fsdp sharded final model

* fix: follow upstream transformer shard size

* fix: handle multiple model files

* fix redundant condition, tighten to safetensors, keep shard size small

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:30 -04:00
ゆり
63a58cfec1 feat: support excess_length_strategy for RL trainers (#3578) [skip ci]
* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
  within sequence_len while preserving the full prompt, then decodes
  back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:10 -04:00
madScientist10
3985ec2f67 feat: add FineGrainedFP8Config support for model quantization (#3587) [skip ci]
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.

Made-with: Cursor
2026-04-12 20:50:37 -04:00
Joaquin Hui
a44edda6d7 Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:50:15 -04:00
Wing Lian
66c3e5a3fd better handling of dora merge on Conv layers in Qwen 3.5 (#3599)
* better handling of dora merge on Conv layers in Qwen 3.5

* address issues from code review

* stricter efficient merges for dora since we now have meta model to reference
2026-04-12 10:57:45 -04:00
Wing Lian
b8358aa5ab [gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598) 2026-04-12 10:29:55 -04:00
Joaquin Hui
e079cf16a2 qwen3_5.jinja: handle list content on system messages (#3595) [skip ci]
* qwen3_5.jinja: handle list content on system messages

The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.

Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.

Fixes #3590

* Address pr for other content types

---------

Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 00:58:58 -04:00
Wing Lian
e2f69828d2 [fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci] 2026-04-11 20:22:35 -04:00
Wing Lian
122b50bad6 pre-cache the eot token ids rather than on each iteration (#3594) [skip ci] 2026-04-11 20:05:21 -04:00
Wing Lian
e77a185e86 upgrade transformers to use v5.5.3 (#3593) 2026-04-10 17:08:14 -04:00
Wing Lian
29fa4dedbb Gemma4 fixes and profiler (#3591) 2026-04-10 16:46:17 -04:00
Wing Lian
315cdeede9 handle trainable/masked spans in content and reasoning content (#3592) 2026-04-10 14:11:10 -04:00
NanoCode012
e7a6a5b529 fix: move warning after we've set any overrides (#3589) [skip ci] 2026-04-10 13:00:47 -04:00
NanoCode012
bfb4da1d25 fix: document jinja2 file path support (#3588) [skip ci] 2026-04-10 13:00:26 -04:00
floaty3
4dfa0a59b2 Add uninstall command to cut_cross_entropy import message (#3583) [skip ci] 2026-04-10 13:00:07 -04:00
Wing Lian
4ef608dda3 fix ddp/fsdp w gemma4 (#3584)
* fix ddp/fsdp w gemma4

* address pr comments

* activation offloading fix and update agent docs for gemma4
2026-04-09 20:02:36 -07:00
NanoCode012
7daf7d96f1 fix: regex for unfrozen language tower (#3586) [skip ci]
* fix: regex for unfrozen language tower

* fix: other leftover regex
2026-04-08 08:18:11 -07:00
Wing Lian
7c56809c7f use vllm 0.19.0 for torch 2.10.0 (#3582) 2026-04-07 08:09:49 -07:00
NanoCode012
149178ddb7 chore: cleanup post release v0.16 (#3577)
* fix: remove unneeded debug log

* fix: cleanup

* feat: add dense gemma config and cleanup

* feat: add cce support

* update notes and set torch compile

* fix patch for new number of return vals

* fixes for gemma4

* fix packing bug

* use updated cce for mm

* fix: pass in kv cache func when avail for transformers 5.5

* feat: update examples with flex variant and readme

* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-06 10:10:52 -07:00
NanoCode012
dc638e723f fix(config): add cce and liger to nemotron-h example (#3573) [skip ci] 2026-04-06 10:10:25 -07:00
Wing Lian
6f15da4cac make it easier for agents to discover docs (#3579) [skip ci]
* make it easier for agents to discover docs

* fixup pr comments
2026-04-06 10:00:55 -07:00
Maxime
900eec7988 Fix DO_NOT_TRACK not being correctly handled (#3580)
* Fix DO_NOT_TRACK not being correctly handled

* add unit tests and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-04 05:16:58 -04:00
115 changed files with 7346 additions and 1722 deletions

View File

@@ -31,7 +31,10 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
# Install axolotl + dev and test dependencies from lockfile
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
pre-commit install
# test

View File

@@ -6,7 +6,7 @@ on:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- 'pyproject.toml'
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"

View File

@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
- "tests/e2e/multigpu/**.py"
- "pyproject.toml"
- ".github/workflows/multi-gpu-e2e.yml"
- "scripts/cutcrossentropy_install.py"
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
- "src/axolotl/utils/distributed.py"
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
# Cancel jobs on the same ref if a new one is triggered
concurrency:
@@ -33,19 +31,19 @@ jobs:
fail-fast: false
matrix:
include:
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
@@ -53,7 +51,6 @@ jobs:
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -75,7 +72,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -8,6 +8,9 @@ on:
permissions: {}
env:
UV_SYSTEM_PYTHON: "1"
jobs:
setup_release:
name: Create Release
@@ -41,11 +44,15 @@ jobs:
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: |
pip3 install wheel packaging==26.0
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
uv pip install wheel packaging
uv pip install --no-build-isolation -e .
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Extract tag name
id: tag

View File

@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '.github/workflows/tests-nightly.yml'
- ".github/workflows/tests-nightly.yml"
permissions:
contents: read
env:
UV_SYSTEM_PYTHON: "1"
jobs:
pre-commit:
name: pre-commit
@@ -20,7 +23,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
cache: "pip" # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -43,7 +46,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.9.1", "2.10.0"]
timeout-minutes: 20
@@ -61,36 +64,34 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Override with nightly HF packages
run: |
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
- name: Ensure axolotl CLI was installed
run: |
@@ -102,9 +103,6 @@ jobs:
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
@@ -136,7 +134,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout
@@ -157,7 +154,7 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:

View File

@@ -6,21 +6,19 @@ on:
branches:
- "main"
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
types: [opened, synchronize, reopened, ready_for_review]
paths:
- "**.py"
- "pyproject.toml"
- ".github/workflows/*.yml"
- "cicd/cicd.sh"
- "cicd/Dockerfile-uv.jinja"
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
@@ -33,6 +31,7 @@ permissions:
env:
TRANSFORMERS_IS_CI: "yes"
UV_SYSTEM_PYTHON: "1"
jobs:
pre-commit:
@@ -44,7 +43,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: 'pip' # caching pip dependencies
cache: "pip" # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
SKIP: no-commit-to-branch
@@ -94,32 +93,25 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
- name: Ensure axolotl CLI was installed
run: |
@@ -188,38 +180,42 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
uv pip install torch==${{ matrix.pytorch_version }} torchvision
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
- name: Install dependencies
run: |
pip3 show torch
uv pip install packaging setuptools_scm build wheel psutil
python -m build --no-isolation --sdist
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
python scripts/cutcrossentropy_install.py --uv | sh
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Verify agent docs are discoverable
run: |
# Agent docs live in docs/agents/ (source of truth) and are resolved
# at runtime from the repo checkout or via `axolotl fetch docs`
axolotl agent-docs --list
axolotl agent-docs | grep -q "Fine-tuning framework"
axolotl agent-docs grpo | grep -q "GRPO"
axolotl agent-docs sft | grep -q "SFT"
python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"
- name: Show HF cache
run: hf cache ls
@@ -281,7 +277,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -302,7 +297,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
@@ -364,7 +359,7 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
axolotl fetch examples # Download example configs
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
axolotl agent-docs grpo # Topic-specific agent reference
axolotl config-schema # Dump config JSON schema
```
## Training Methods
@@ -35,6 +38,8 @@ Agent-specific references:
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
## Config Pattern

View File

@@ -1,6 +1,7 @@
include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include VERSION
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md
recursive-include axolotl *.py

View File

@@ -86,7 +86,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1
### Google Colab
@@ -95,11 +95,19 @@ Features:
### Installation
#### Using pip
```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# install uv if you don't already have it installed (restart shell after)
curl -LsSf https://astral.sh/uv/install.sh | sh
# change depending on system
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
@@ -110,7 +118,7 @@ axolotl fetch deepspeed_configs # OPTIONAL
Installing with Docker can be less error prone than installing in your own environment.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
```
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
@@ -157,6 +165,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
## AI Agent Support
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
```bash
# Show overview and available training methods
axolotl agent-docs
# Topic-specific references
axolotl agent-docs sft # supervised fine-tuning
axolotl agent-docs grpo # GRPO online RL
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
axolotl agent-docs reward_modelling # outcome and process reward models
axolotl agent-docs pretraining # continual pretraining
axolotl agent-docs --list # list all topics
# Dump config schema for programmatic use
axolotl config-schema
axolotl config-schema --field adapter
```
If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.
## 🤝 Getting Help
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support

View File

@@ -134,7 +134,6 @@ quartodoc:
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
@@ -327,7 +326,6 @@ website:
- section: "Advanced Features"
contents:
- docs/fsdp_qlora.qmd
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd

View File

@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
# Override with nightly HF packages for nightly builds
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
uv pip install --no-deps \
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
"peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"trl @ git+https://github.com/huggingface/trl.git@main" \
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
fi
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,54 +0,0 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,7 +1,7 @@
#!/bin/bash
set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"
set -o pipefail
for i in 1 2 3; do

View File

@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {

View File

@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
df_template = template_env.get_template(dockerfile)
df_args = {

View File

@@ -32,7 +32,7 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
fi && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -33,7 +33,6 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean

View File

@@ -0,0 +1,198 @@
# Model Architectures — Agent Reference
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## VLM (Vision Language Model) Quick Start
All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```
## Plugins & Optimizations
### Cut Cross Entropy (CCE)
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
### ScatterMoE Kernels
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
## Gemma 4
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```
### Auto-detection
Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
### Multi-GPU
| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
FSDP2 config:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```
### MoE (26B-A4B)
- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
### VLM (Vision) Training
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
### Common issues
| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
## Gemma 3
**Models**: `google/gemma-3-*`
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
## Qwen 3.5 MoE
**Models**: `Qwen/Qwen3.5-35B-A3B`
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
- name_pattern: 'linear_attn\.conv1d\.weight'
threshold: 1.3
```
## General MoE Notes
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin

View File

@@ -0,0 +1,181 @@
# New Model Support — Agent Reference
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
## Quick Validation Checklist
When testing a new model, run through these checks in order:
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.52.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
## Loss Debugging
### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.52.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:
```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader
cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()
# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
```
If direct loss is correct (~1.0) but trainer reports 34x higher, check `model_accepts_loss_kwargs` (see below).
### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
## Multimodal Models (ForConditionalGeneration)
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`
### Why this matters
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |
### Required extra inputs
Multimodal models require special inputs during training even for text-only data:
| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:
| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
## Sample Packing
### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
packed_sequence_mask = find_packed_sequence_indices(position_ids)
```
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
```python
# In compute_loss():
if (
self.args.sample_packing
and model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
```
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
## Attention Backend Selection
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
## Cut Cross Entropy (CCE)
### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
## MoE Models
### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
## Where to Add Model-Specific Fixes
| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |

View File

@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
## Profiling
To profile training and identify optimization opportunities:
```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.
To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
## File Map

View File

@@ -108,6 +108,14 @@ datasets:
type: chat_template
```
::: {.callout-tip}
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
```yaml
chat_template_jinja: ./path/to/my_template.jinja2
```
:::
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::
@@ -294,6 +302,113 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Content parts with per-part training control
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think step by step...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
}
]
}
```
The configuration is the same as standard `chat_template` — no extra fields needed:
```yaml
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
Each content part supports:
- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
**Split BEFORE spaces** (space goes with the next part):
```json
[
{"type": "text", "text": "Let me think...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
```
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
```json
[
{"type": "text", "text": "Let me think... ", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
```json
[
{"type": "text", "text": "Thinking:\n", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::
::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::
##### Per-part training on reasoning_content
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": true}
]
}
]
}
```
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.

View File

@@ -76,8 +76,9 @@ datasets:
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
```
#### Remote Hosts
@@ -208,17 +209,17 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
```
>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, perform an editable install of Axolotl:
You will now be in the container. Next, install Axolotl with dev dependencies:
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
uv sync --extra flash-attn --extra deepspeed --group dev --group test
source .venv/bin/activate
```
### Attach To Container

View File

@@ -6,23 +6,30 @@ format:
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
This section describes the different Docker images that are released by AxolotlAI at
[Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
:::
::: {.callout-tip}
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
#### Tags format
@@ -32,8 +39,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.8.0`
- `main-base-py3.11-cu128-2.9.1`
- `main-base-py3.12-cu128-2.10.0`
- `main-base-py3.12-cu130-2.9.1`
- `main-base-py3.12-cu130-2.10.0`
## Main
@@ -41,11 +50,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
#### Tags format {#sec-main-tags}
@@ -53,7 +61,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
main-latest
# nightly build
@@ -71,11 +79,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu128-2.8.0`
- `main-py3.11-cu128-2.9.1`
- `main-py3.12-cu128-2.10.0`
- `main-py3.12-cu130-2.9.1`
- `main-py3.12-cu130-2.10.0`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0`
- `main-20260315-py3.11-cu128-2.9.1`
- `0.12.0`
## Cloud
@@ -90,11 +99,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
| Variant | Image | Docker Hub |
|---------|-------|------------|
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
#### Tags format

View File

@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.6.0
- PyTorch ≥2.9.0
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
## Installation {#sec-installation}
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
### Quick Install {#sec-uv}
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
Install uv if not already installed
Install uv if not already installed:
```{.bash}
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
then create the venv and activate
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
```{.bash}
export UV_TORCH_BACKEND=cu126
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
source .venv/bin/activate
```
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
```
### Edge/Development Build {#sec-edge-build}
@@ -82,14 +48,17 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed
source .venv/bin/activate
```
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
### Docker {#sec-docker}
```{.bash}
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
```
For development with Docker:
@@ -106,12 +75,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
--ulimit memlock=-1 --ulimit stack=67108864 \
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
axolotlai/axolotl:main-latest
axolotlai/axolotl-uv:main-latest
```
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
@@ -122,7 +91,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Use `axolotlai/axolotl-cloud-uv:main-latest`
- Available on:
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
@@ -141,7 +110,7 @@ For providers supporting Docker:
### macOS {#sec-macos}
```{.bash}
pip3 install --no-build-isolation -e '.'
uv pip install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
@@ -152,21 +121,44 @@ See @sec-troubleshooting for Mac-specific issues.
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::
## Environment Managers {#sec-env-managers}
## Migrating from pip to uv {#sec-migrating}
### Conda/Pip venv {#sec-conda}
If you have an existing pip-based Axolotl installation, you can migrate to uv:
1. Install Python ≥3.11
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
```
```{.bash}
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# Create a fresh venv (recommended for a clean start)
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
source .venv/bin/activate
# Reinstall axolotl
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
```
## Using pip (Alternative) {#sec-pip}
If you are unable to install uv, you can still use pip directly.
::: {.callout-important}
Please make sure to have PyTorch installed before installing Axolotl with pip.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
For editable/development installs:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -8,6 +8,7 @@ format:
## Supported Models
- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-4 {#sec-gemma-4}
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
chat_template: gemma4
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
:::
::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}

View File

@@ -1,53 +0,0 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---
### Overview
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
standard industry baselines.
::: {.callout-important}
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
:::
### Installation
The following will install the correct unsloth and extras from source.
```bash
python scripts/unsloth_install.py | sh
```
### Usage
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
Our unsloth integration is currently limited to the following model architectures:
- llama
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
### Limitations
- Single GPU only; e.g. no multi-gpu support
- No deepspeed or FSDP support (requires multi-gpu)
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
- No MoE support.

View File

@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
uv pip uninstall causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).

View File

@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
```
```bash
pip3 install . --no-build-isolation --no-deps
uv pip install . --no-build-isolation --no-deps
```
## Optimization Guides

View File

@@ -17,8 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
},
{

View File

@@ -16,8 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -26,8 +26,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_r: 32

View File

@@ -26,8 +26,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_r: 32

View File

@@ -22,8 +22,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_model_dir:

View File

@@ -10,17 +10,16 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:
```bash
pip3 install timm==1.0.17
uv pip install timm==1.0.17
# for loading audio data
pip3 install librosa==0.11.0
uv pip install librosa==0.11.0
```
3. Download sample dataset files

View File

@@ -1,19 +1,12 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
# Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
# LoRA to the text backbone via lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
# via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
# (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
base_model: google/gemma-4-26B-A4B
@@ -24,7 +17,7 @@ plugins:
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
@@ -54,12 +47,9 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using regex to match only the text decoder attention projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:
@@ -73,7 +63,7 @@ lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project: gemma4-qlora
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
@@ -93,8 +83,7 @@ gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
# FA2 not supported
sdp_attention: true
warmup_ratio: 0.1

View File

@@ -0,0 +1,71 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora-flex
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
flex_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,69 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using regex to match only the text decoder attention projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

60
examples/gemma4/README.md Normal file
View File

@@ -0,0 +1,60 @@
# Finetune Google's Gemma 4 with Axolotl
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
axolotl train examples/gemma4/31b-qlora.yaml
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
axolotl train examples/gemma4/31b-qlora-flex.yaml
```
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
## Flex Attention
Reduce ~40% VRAM (at the cost of up to half throughput) by setting the below (shown in `examples/gemma4/31b-qlora-flex.yaml`):
```yaml
torch_compile: true
flex_attention: true
```
This works for both the MoE and Dense model.
## Limitations
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
- **LoRA kernels**: Not supported due to KV-sharing layers.
- **lora_target_linear**: Incompatible for multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy and has not been tested.
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,62 @@
# Gemma 4 E2B Vision LoRA
#
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
base_model: google/gemma-4-E2B-it
processor_type: AutoProcessor
freeze_mm_modules: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: gemma4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/gemma4-e2b-vision-lora
adapter: lora
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Target language model only — vision encoder is frozen via freeze_mm_modules
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

View File

@@ -14,8 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
@@ -87,7 +86,7 @@ for more information about using a special vllm-openai docker image for inferenc
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash

View File

@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,8 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl.
2. Install `timm` for vision model support:
```bash
pip install timm==1.0.19
uv pip install timm==1.0.19
```
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

View File

@@ -14,8 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
uv pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:

View File

@@ -23,7 +23,7 @@ Note: This is still experimental given it is based on transformers v5 RC.
git checkout transformers-v5
# Install packages for transformers v5
pip install -e .
uv pip install -e .
```
4. Run the fine-tuning:

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.6'
uv pip install 'mistral-common[opencv]==1.8.6'
```
2. Download the example dataset image:

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
uv pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:

View File

@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
3. Install transformers from main
```bash
pip install git+https://github.com/huggingface/transformers.git
uv pip install git+https://github.com/huggingface/transformers.git
```
4. Run one of the example configs:

View File

@@ -1,5 +1,15 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
@@ -22,8 +32,6 @@ dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
@@ -31,16 +39,16 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
# Attention projection layers (present in ~12 attention layers out of 88)
- q_proj
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:

View File

@@ -1,6 +1,16 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
@@ -23,8 +33,6 @@ dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
@@ -36,11 +44,12 @@ lora_target_modules:
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:

View File

@@ -12,7 +12,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
```
4. Run the finetuning example:

View File

@@ -26,8 +26,8 @@ sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
- model.language_model.*
- lm_head.*
wandb_project:
wandb_entity:

View File

@@ -0,0 +1,62 @@
# Qwen 3.5 35B-A3B MoE Vision LoRA
#
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
# 256 experts, 8 active per token, with early-fusion vision support.
base_model: Qwen/Qwen3.5-35B-A3B
processor_type: AutoProcessor
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
adapter: lora
sequence_len: 4096
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

View File

@@ -10,7 +10,7 @@
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.

View File

@@ -11,8 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -13,14 +13,13 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
uv pip install num2words==0.5.14
```
3. Run the finetuning example:

View File

@@ -12,16 +12,15 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Please install the below.
```bash
# audio
pip3 install librosa==0.11.0
pip3 install 'mistral_common[audio]==1.8.3'
uv pip install librosa==0.11.0
uv pip install 'mistral_common[audio]==1.8.3'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -1,15 +1,165 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"
[project]
name = "axolotl"
dynamic = ["version", "dependencies", "optional-dependencies"]
dynamic = ["version"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
# license = "Apache-2.0"
dependencies = [
# Core ML stack
"torch>=2.6.0",
"packaging==26.0",
"huggingface_hub>=1.1.7",
"peft>=0.19.1,<0.20.0",
"tokenizers>=0.22.1",
"transformers==5.5.4",
"accelerate==1.13.0",
"datasets>=4.8.4,<4.9.0",
"trl==1.1.0",
"hf_xet==1.4.3",
"kernels==0.13.0",
"trackio>=0.16.1",
"typing-extensions>=4.15.0",
"optimum==1.16.2",
"hf_transfer",
"sentencepiece",
"gradio>=6.2.0,<7.0",
"modal==1.3.0.post1",
"pydantic>=2.10.6",
"addict",
"fire",
"PyYAML>=6.0",
"requests",
"wandb",
"einops",
"colorama",
"numba>=0.61.2",
"numpy>=2.2.6",
# Evaluation & metrics
"evaluate==0.4.1",
"scipy",
"nvidia-ml-py==12.560.30",
"art",
"tensorboard",
"python-dotenv==1.0.1",
# Remote filesystems
"s3fs>=2024.5.0",
"gcsfs>=2025.3.0",
"adlfs>=2024.5.0",
"ocifs==1.3.2",
"zstandard==0.22.0",
"fastcore",
# lm eval harness
"lm_eval==0.4.11",
"langdetect==1.0.9",
"immutabledict==4.2.0",
"antlr4-python3-runtime==4.13.2",
"schedulefree==1.4.1",
"openenv-core==0.1.0",
# Axolotl contribs
"axolotl-contribs-lgpl==0.0.7",
"axolotl-contribs-mit==0.0.6",
# Telemetry
"posthog==6.7.11",
"mistral-common==1.11.0",
# Platform-specific (Linux only)
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
"triton>=3.4.0 ; sys_platform != 'darwin'",
"xformers>=0.0.23.post1 ; sys_platform != 'darwin'",
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
# Architecture-specific
"fla-core==0.4.1 ; platform_machine != 'aarch64'",
"flash-linear-attention==0.4.1 ; platform_machine != 'aarch64'",
]
[project.optional-dependencies]
flash-attn = ["flash-attn==2.8.3"]
ring-flash-attn = [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
]
deepspeed = [
"deepspeed>=0.18.6,<0.19.0",
"deepspeed-kernels",
]
mamba-ssm = [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
]
auto-gptq = [
"auto-gptq==0.5.1",
]
mlflow = [
"mlflow",
]
galore = [
"galore_torch",
]
apollo = [
"apollo-torch",
]
optimizers = [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
]
ray = [
"ray[train]>=2.52.1",
]
vllm = [
"vllm>=0.15.0",
]
llmcompressor = [
"llmcompressor>=0.10.0",
]
fbgemm-gpu = ["fbgemm-gpu-genai>=1.3.0"]
opentelemetry = [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
]
[dependency-groups]
dev = [
"black",
"mypy",
"pre-commit",
"types-requests",
"quartodoc",
"jupyter",
"blobfile",
"tiktoken",
]
test = [
"codecov",
"codecov-cli",
"pytest",
"pytest-cov",
"pytest-retry",
"pytest-sugar",
"pytest-xdist",
"tbparse",
]
[project.scripts]
axolotl = "axolotl.cli.main:main"
@@ -18,18 +168,15 @@ Homepage = "https://axolotl.ai/"
Documentation = "https://docs.axolotl.ai/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
[tool.setuptools_scm]
[tool.setuptools]
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
@@ -67,5 +214,43 @@ markers = [
"slow: marks tests as slow",
]
# UV specific configuration
[tool.uv]
prerelease = "allow"
conflicts = [
[
{ package = "axolotl" },
{ extra = "vllm" },
],
[
{ package = "axolotl" },
{ extra = "flash-attn" },
],
[
{ package = "axolotl" },
{ extra = "ring-flash-attn" },
],
[
{ package = "axolotl" },
{ extra = "mamba-ssm" },
],
[
{ package = "axolotl" },
{ extra = "auto-gptq" },
],
[
{ package = "axolotl" },
{ extra = "fbgemm-gpu" },
],
[
{ package = "axolotl" },
{ extra = "llmcompressor" },
],
]
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]
mamba-ssm = [{ requirement = "torch", match-runtime = true }]
causal-conv1d = [{ requirement = "torch", match-runtime = true }]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]
auto-gptq = [{ requirement = "torch", match-runtime = true }]

View File

@@ -1,8 +0,0 @@
black
mypy
pre-commit
types-requests
quartodoc
jupyter
blobfile
tiktoken

View File

@@ -1,8 +0,0 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry
pytest-sugar
pytest-xdist
tbparse

View File

@@ -1,78 +0,0 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.7.0
# END section
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.5.0
accelerate==1.13.0
datasets==4.5.0
deepspeed>=0.18.6,<0.19.0
trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
fla-core==0.4.1
flash-linear-attention==0.4.1
trackio>=0.16.1
typing-extensions>=4.15.0
optimum==1.16.2
hf_transfer
sentencepiece
gradio>=6.2.0,<7.0
modal==1.3.0.post1
pydantic>=2.10.6
addict
fire
PyYAML>=6.0
requests
wandb
einops
colorama
numba>=0.61.2
numpy>=2.2.6
# qlora things
evaluate==0.4.1
scipy
nvidia-ml-py==12.560.30
art
tensorboard
python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2025.3.0
adlfs>=2024.5.0
ocifs==1.3.2
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.11
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.17.0
openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.11.0

1518
scripts/analyze_profile.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,479 @@
#!/usr/bin/env python3
"""Build a disposable Hugging Face Kernel Hub package for ScatterMoE LoRA.
This script does not move or edit the in-tree Axolotl kernel sources. It copies
``src/axolotl/integrations/kernels/libs/scattermoe_lora`` into an ignored
build directory and emits a universal HF kernels project that can be pushed to
the Hub.
"""
from __future__ import annotations
import argparse
import fnmatch
import hashlib
import json
import os
import shutil
import subprocess
import sys
from importlib import metadata
from pathlib import Path
PACKAGE_NAME = "scattermoe_lora"
BUILD_VARIANT = "torch-universal"
DEFAULT_REPO_ID = "kernels-community/scattermoe-lora"
HF_REPO_TYPE = "kernel"
HF_KERNEL_URL_PREFIX = "https://hf.co/kernels"
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_SOURCE_DIR = (
REPO_ROOT / "src" / "axolotl" / "integrations" / "kernels" / "libs" / PACKAGE_NAME
)
DEFAULT_OUTPUT_DIR = REPO_ROOT / "build" / "hf-kernels" / PACKAGE_NAME
EXCLUDED_DIRS = {
"__pycache__",
".mypy_cache",
".pytest_cache",
".ruff_cache",
}
EXCLUDED_FILE_PATTERNS = {
"*.pyc",
"*.pyo",
"*.so",
".DS_Store",
}
TEXT_REPLACEMENTS = {
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import": (
"from .selective_dequant import"
),
"from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import": (
"from .selective_dequant_kernel import"
),
"from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import": (
"from .ops import"
),
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Copy Axolotl's ScatterMoE LoRA Triton kernels into a disposable "
"HF Kernel Hub universal package."
)
)
parser.add_argument(
"--source-dir",
type=Path,
default=DEFAULT_SOURCE_DIR,
help=f"ScatterMoE LoRA source package to copy. Default: {DEFAULT_SOURCE_DIR}",
)
parser.add_argument(
"--output-dir",
type=Path,
default=DEFAULT_OUTPUT_DIR,
help=f"Destination build/dist directory. Default: {DEFAULT_OUTPUT_DIR}",
)
parser.add_argument(
"--repo-id",
default=DEFAULT_REPO_ID,
help=f"HF Hub repo id to write into build.toml. Default: {DEFAULT_REPO_ID}",
)
parser.add_argument(
"--version",
type=int,
default=1,
help="Kernel major version written to build.toml and metadata.json.",
)
parser.add_argument(
"--force",
action="store_true",
help="Delete the output directory first if it already exists.",
)
parser.add_argument(
"--no-source-layout",
action="store_true",
help="Only write the shippable build/ tree, not torch-ext/ sources.",
)
parser.add_argument(
"--upload",
action="store_true",
help=(
"Upload the generated universal kernel package with huggingface_hub. "
"This bypasses kernel-builder and is intended for pure Python/Triton "
"universal kernels."
),
)
parser.add_argument(
"--private",
action="store_true",
help="Create the HF Hub repo as private when used with --upload.",
)
parser.add_argument(
"--skip-version-branch",
action="store_true",
help="With --upload, only upload main and skip the v<version> branch.",
)
return parser.parse_args()
def should_skip_file(path: Path) -> bool:
return any(
fnmatch.fnmatch(path.name, pattern) for pattern in EXCLUDED_FILE_PATTERNS
)
def iter_source_files(source_dir: Path) -> list[Path]:
files: list[Path] = []
for root, dirs, filenames in os.walk(source_dir):
dirs[:] = sorted(d for d in dirs if d not in EXCLUDED_DIRS)
for filename in sorted(filenames):
path = Path(root) / filename
if not should_skip_file(path):
files.append(path)
return files
def content_hash(source_dir: Path) -> str:
digest = hashlib.sha1()
for path in iter_source_files(source_dir):
rel = path.relative_to(source_dir).as_posix()
digest.update(rel.encode("utf-8"))
digest.update(b"\0")
digest.update(path.read_bytes())
digest.update(b"\0")
return digest.hexdigest()[:10]
def git_revision() -> str:
try:
result = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
cwd=REPO_ROOT,
check=True,
capture_output=True,
text=True,
)
except (OSError, subprocess.CalledProcessError):
return "unknown"
return result.stdout.strip() or "unknown"
def transform_python_source(text: str, rel_path: Path, op_namespace: str) -> str:
for old, new in TEXT_REPLACEMENTS.items():
text = text.replace(old, new)
if rel_path.as_posix() == "gemma4_experts.py":
text = text.replace(
" from axolotl.integrations.kernels.constants import resolve_experts_class",
(
" raise RuntimeError(\n"
' "patch_gemma4_scattermoe is only available from the in-tree Axolotl "\n'
' "integration. Use register_scattermoe_experts() with the standalone "\n'
' "HF kernel package."\n'
" )"
),
)
return text.replace("scattermoe::", f"{op_namespace}::")
def copy_package(source_dir: Path, package_dir: Path, op_namespace: str) -> None:
for source in iter_source_files(source_dir):
rel_path = source.relative_to(source_dir)
destination = package_dir / rel_path
destination.parent.mkdir(parents=True, exist_ok=True)
if source.suffix == ".py":
text = source.read_text(encoding="utf-8")
text = transform_python_source(text, rel_path, op_namespace)
destination.write_text(text, encoding="utf-8")
else:
shutil.copy2(source, destination)
write_ops_module(package_dir / "_ops.py", op_namespace)
def write_ops_module(path: Path, op_namespace: str) -> None:
path.write_text(
"\n".join(
[
"import torch",
"",
f"ops = torch.ops.{op_namespace}",
"",
"",
"def add_op_namespace_prefix(op_name: str) -> str:",
f' return f"{op_namespace}::{{op_name}}"',
"",
]
),
encoding="utf-8",
)
def write_build_toml(path: Path, repo_id: str, version: int) -> None:
lines = [
"[general]",
f'name = "{PACKAGE_NAME}"',
"universal = true",
f"version = {version}",
"",
]
if repo_id:
lines.extend(
[
"[general.hub]",
f'repo-id = "{repo_id}"',
"",
]
)
path.write_text("\n".join(lines), encoding="utf-8")
def write_flake(path: Path) -> None:
path.write_text(
"""{
description = "Flake for scattermoe_lora kernel";
inputs = {
builder.url = "github:huggingface/kernels";
};
outputs =
{
self,
builder,
}:
builder.lib.genKernelFlakeOutputs {
inherit self;
path = ./.;
};
}
""",
encoding="utf-8",
)
def write_readme(path: Path, repo_id: str, source_hash: str, op_namespace: str) -> None:
repo_display = repo_id or "<your-org>/scattermoe-lora"
path.write_text(
f"""---
library_name: kernels
license: apache-2.0
tags:
- kernel
- kernels
---
# ScatterMoE LoRA
Standalone Hugging Face Kernel Hub package for Axolotl's ScatterMoE LoRA Triton kernels.
This package is generated from Axolotl's in-tree `scattermoe_lora` sources and is exported as a universal kernel because the implementation is Python/Triton rather than a precompiled C++/CUDA extension.
```python
from kernels import get_kernel
scattermoe_lora = get_kernel("{repo_display}")
```
Export metadata:
- source package: `src/axolotl/integrations/kernels/libs/scattermoe_lora`
- source revision: `{git_revision()}`
- source content hash: `{source_hash}`
- torch custom op namespace: `{op_namespace}`
The generated `build/torch-universal/{PACKAGE_NAME}` directory is the shippable Hub artifact. `torch-ext/{PACKAGE_NAME}` is included so `kernel-builder build-and-copy` can regenerate the universal build tree if desired.
""",
encoding="utf-8",
)
def write_metadata(path: Path, version: int) -> None:
path.write_text(
json.dumps({"version": version}, indent=2, sort_keys=True) + "\n",
encoding="utf-8",
)
def prepare_output_dir(output_dir: Path, force: bool) -> None:
if output_dir.exists():
if not force:
raise FileExistsError(
f"{output_dir} already exists. Re-run with --force to replace it."
)
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True)
def build_package(args: argparse.Namespace) -> Path:
source_dir = args.source_dir.resolve()
output_dir = args.output_dir.resolve()
if not source_dir.is_dir():
raise FileNotFoundError(f"source package does not exist: {source_dir}")
if not (source_dir / "__init__.py").is_file():
raise FileNotFoundError(f"source package is missing __init__.py: {source_dir}")
source_hash = content_hash(source_dir)
op_namespace = f"_{PACKAGE_NAME}_{source_hash}"
prepare_output_dir(output_dir, args.force)
write_build_toml(output_dir / "build.toml", args.repo_id, args.version)
write_flake(output_dir / "flake.nix")
write_readme(output_dir / "README.md", args.repo_id, source_hash, op_namespace)
if not args.no_source_layout:
copy_package(source_dir, output_dir / "torch-ext" / PACKAGE_NAME, op_namespace)
build_package_dir = output_dir / "build" / BUILD_VARIANT / PACKAGE_NAME
copy_package(source_dir, build_package_dir, op_namespace)
write_metadata(build_package_dir.parent / "metadata.json", args.version)
return output_dir
def upload_package(args: argparse.Namespace, output_dir: Path) -> None:
if not args.repo_id:
raise ValueError("--repo-id is required when using --upload")
try:
from huggingface_hub import HfApi, constants as hf_constants
except ImportError as exc:
raise RuntimeError(
"--upload requires huggingface_hub. Install it or run the upload "
"manually with the Hugging Face CLI."
) from exc
try:
hub_version = metadata.version("huggingface_hub")
except metadata.PackageNotFoundError:
hub_version = "unknown"
accepted_repo_types = getattr(
hf_constants,
"REPO_TYPES_WITH_KERNEL",
getattr(hf_constants, "REPO_TYPES", ()),
)
if HF_REPO_TYPE not in accepted_repo_types:
raise RuntimeError(
"Your huggingface_hub installation does not support "
f"repo_type={HF_REPO_TYPE!r} (found huggingface_hub {hub_version}). "
f"Upgrade this interpreter with: {sys.executable} -m pip install --upgrade "
"'huggingface_hub>=1.10.0'"
)
# huggingface_hub 1.11.0 has partial kernel support: create_repo accepts
# "kernel", but upload_folder/create_commit still validate against the
# older REPO_TYPES list. Extend it in-process so those helpers use the
# /api/kernels/... endpoints until upstream broadens that check.
if HF_REPO_TYPE not in hf_constants.REPO_TYPES:
hf_constants.REPO_TYPES.append(HF_REPO_TYPE)
api = HfApi()
try:
repo_id = api.create_repo(
repo_id=args.repo_id,
repo_type=HF_REPO_TYPE,
private=args.private,
exist_ok=True,
).repo_id
except ValueError as exc:
if "Invalid repo type" in str(exc):
raise RuntimeError(
"huggingface_hub rejected repo_type='kernel'. "
f"This usually means the command is running with an older Hub "
f"client than expected (found huggingface_hub {hub_version} at "
f"{sys.executable}). Upgrade with: {sys.executable} -m pip "
"install --upgrade 'huggingface_hub>=1.10.0'"
) from exc
raise
delete_patterns = [
"build/**",
"torch-ext/**",
"build.toml",
"flake.nix",
"README.md",
]
api.upload_folder(
repo_id=repo_id,
repo_type=HF_REPO_TYPE,
folder_path=output_dir,
revision="main",
delete_patterns=delete_patterns,
commit_message="Upload ScatterMoE LoRA universal kernel",
)
print(f"Uploaded main branch: {HF_KERNEL_URL_PREFIX}/{repo_id}")
if args.skip_version_branch:
return
version_branch = f"v{args.version}"
api.create_branch(
repo_id=repo_id,
repo_type=HF_REPO_TYPE,
branch=version_branch,
revision="main",
exist_ok=True,
)
api.upload_folder(
repo_id=repo_id,
repo_type=HF_REPO_TYPE,
folder_path=output_dir,
revision=version_branch,
delete_patterns=delete_patterns,
commit_message=f"Upload ScatterMoE LoRA universal kernel {version_branch}",
)
print(
f"Uploaded version branch: "
f"{HF_KERNEL_URL_PREFIX}/{repo_id}/tree/{version_branch}"
)
def main() -> int:
args = parse_args()
try:
output_dir = build_package(args)
if args.upload:
upload_package(args, output_dir)
except Exception as exc:
print(f"error: {exc}", file=sys.stderr)
return 1
print(f"Wrote ScatterMoE LoRA HF kernel package to: {output_dir}")
print(f"Shippable artifact: {output_dir / 'build' / BUILD_VARIANT / PACKAGE_NAME}")
if args.upload:
print(f'Load it with: get_kernel("{args.repo_id}", version={args.version})')
print(f"Uploaded as Hugging Face repo_type={HF_REPO_TYPE!r}.")
return 0
print("Next step:")
print(" upload this universal Python/Triton kernel directly:")
print(
f" python3 {Path(__file__).as_posix()} "
f"--repo-id {args.repo_id} --force --upload"
)
if shutil.which("kernel-builder") is None:
print(" optional: install kernel-builder for full Nix-based builds:")
print(
" curl -fsSL "
"https://raw.githubusercontent.com/huggingface/kernels/main/install.sh "
"| bash"
)
else:
print(" optional: upload with kernel-builder:")
print(f" cd {output_dir}")
print(" kernel-builder build-and-upload")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
)

View File

@@ -1,40 +0,0 @@
# noqa
import sys
try:
import torch
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
)

230
setup.py
View File

@@ -1,230 +0,0 @@
"""setup.py for axolotl"""
import os
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from setuptools import find_packages, setup
def parse_requirements(extras_require_map):
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = "deepspeed" in line or "mamba-ssm" in line
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip on ARM64
skip_packages = [
"torchao",
"fla-core",
"flash-linear-attention",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"xformers",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.8.0" # default to torch 2.8.0
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm>=0.17.1"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:
extras_require_map["vllm"] = ["vllm==0.14.0"]
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
if install_xformers:
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
if install_xformers:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
if install_xformers:
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if install_xformers:
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if install_xformers:
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links, extras_require_map
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
"r",
encoding="utf-8",
) as fin:
version_ = fin.read().strip()
return version_
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
],
"deepspeed": [
"deepspeed==0.18.2",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]>=2.52.1",
],
"vllm": [
"vllm==0.10.0",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
"opentelemetry": [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
setup(
version=get_package_version(),
package_dir={"": "src"},
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require=extras_require_build,
)

View File

@@ -0,0 +1,108 @@
"""Bundled agent documentation for axolotl.
These docs are optimized for consumption by AI coding agents.
The source of truth is docs/agents/*.md and AGENTS.md in the repo root.
This module resolves those paths at runtime — no files are duplicated
into the package.
For pip-only installs (no repo checkout), run `axolotl fetch docs` first
to download the docs locally.
"""
from pathlib import Path
# Topic name -> (filename in docs/agents/, fallback filename for AGENTS.md)
TOPICS = {
"overview": "AGENTS.md",
"sft": "docs/agents/sft.md",
"grpo": "docs/agents/grpo.md",
"preference_tuning": "docs/agents/preference_tuning.md",
"reward_modelling": "docs/agents/reward_modelling.md",
"pretraining": "docs/agents/pretraining.md",
"model_architectures": "docs/agents/model_architectures.md",
"new_model_support": "docs/agents/new_model_support.md",
}
def _find_repo_root() -> Path | None:
"""Walk up from this file to find the repo root (contains AGENTS.md)."""
# In an editable install or repo checkout, walk up from
# src/axolotl/cli/agent_docs/ to find the repo root
current = Path(__file__).resolve().parent
while current != current.parent:
if (current / "AGENTS.md").exists() and (current / "docs" / "agents").is_dir():
return current
current = current.parent
return None
def _find_docs_dir() -> Path | None:
"""Find a fetched docs directory (from `axolotl fetch docs`)."""
# axolotl fetch docs --dest defaults to ./docs/ in cwd
cwd_docs = Path.cwd() / "docs" / "agents"
if cwd_docs.is_dir():
return Path.cwd()
return None
def _resolve_path(topic: str) -> Path:
"""Resolve a topic name to the actual file path."""
if topic not in TOPICS:
available = ", ".join(sorted(TOPICS.keys()))
raise FileNotFoundError(f"Unknown topic: {topic!r}. Available: {available}")
relative_path = TOPICS[topic]
# Try repo root first (editable install / repo checkout)
repo_root = _find_repo_root()
if repo_root:
candidate = repo_root / relative_path
if candidate.exists():
return candidate
# Try cwd (fetched docs via `axolotl fetch docs`)
docs_root = _find_docs_dir()
if docs_root:
candidate = docs_root / relative_path
if candidate.exists():
return candidate
# Also check cwd directly for AGENTS.md
if topic == "overview":
cwd_agents = Path.cwd() / "AGENTS.md"
if cwd_agents.exists():
return cwd_agents
raise FileNotFoundError(
f"Could not find {relative_path!r}. "
f"If you installed axolotl via pip, run `axolotl fetch docs` first "
f"to download the documentation."
)
def get_doc(topic: str = "overview") -> str:
"""Return the content of an agent doc by topic name.
Args:
topic: One of the keys in TOPICS, or "overview" (default).
Returns:
The markdown content of the doc.
Raises:
FileNotFoundError: If the topic can't be found.
"""
return _resolve_path(topic).read_text()
def list_topics() -> dict[str, str]:
"""Return a dict of topic name -> first line (title) of each doc."""
result = {}
for topic in sorted(TOPICS.keys()):
try:
path = _resolve_path(topic)
first_line = path.read_text().split("\n", 1)[0].lstrip("# ").strip()
result[topic] = first_line
except FileNotFoundError:
result[topic] = "(not found — run `axolotl fetch docs`)"
return result

View File

@@ -294,7 +294,9 @@ def merge_lora(config: str, **kwargs):
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.argument(
"directory", type=click.Choice(["examples", "deepspeed_configs", "docs"])
)
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
@@ -303,9 +305,10 @@ def fetch(directory: str, dest: Optional[str]):
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
- docs: Full documentation (Quarto markdown files)
Args:
directory: One of `examples`, `deepspeed_configs`.
directory: One of `examples`, `deepspeed_configs`, `docs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)
@@ -340,6 +343,112 @@ def delinearize_llama4(model: str, output: str):
do_delinearize_llama4(model, output)
@cli.command("agent-docs")
@click.argument("topic", required=False, default=None)
@click.option("--list", "list_topics", is_flag=True, help="List available topics")
def agent_docs(topic: Optional[str], list_topics: bool):
"""Show agent-optimized documentation.
Prints reference docs designed for AI coding agents.
These docs are bundled with the package — no network access needed.
\b
Examples:
axolotl agent-docs # overview (start here)
axolotl agent-docs grpo # GRPO reference
axolotl agent-docs sft # SFT reference
axolotl agent-docs --list # list all topics
"""
from axolotl.cli.agent_docs import get_doc, list_topics as _list_topics
if list_topics:
for name, title in _list_topics().items():
click.echo(f" {name:25s} {title}")
return
if topic is None:
topic = "overview"
try:
click.echo(get_doc(topic))
except FileNotFoundError as exc:
raise click.BadParameter(str(exc)) from exc
@cli.command("config-schema")
@click.option(
"--format",
"output_format",
type=click.Choice(["json", "yaml"]),
default="json",
help="Output format (default: json)",
)
@click.option("--field", help="Show schema for a specific field only")
def config_schema(output_format: str, field: Optional[str]):
"""Dump the full config JSON schema.
Useful for AI agents and tooling to discover all available config options,
their types, defaults, and descriptions.
\b
Examples:
axolotl config-schema # full JSON schema
axolotl config-schema --format yaml # YAML format
axolotl config-schema --field adapter # single field
"""
import json
try:
schema = AxolotlInputConfig.model_json_schema()
except (TypeError, ValueError, AttributeError) as exc:
# Fallback: dump field names, types, and defaults when full schema
# generation fails (e.g. torch.dtype not JSON-serializable)
LOG.warning(
"Full JSON schema generation failed, using simplified fallback: %s", exc
)
fields = {}
for name, field_info in AxolotlInputConfig.model_fields.items():
entry = {}
if field_info.description:
entry["description"] = field_info.description
if field_info.default is not None:
try:
json.dumps(field_info.default)
entry["default"] = field_info.default
except (TypeError, ValueError):
entry["default"] = str(field_info.default)
annotation = field_info.annotation
if annotation is not None:
entry["type"] = str(annotation)
fields[name] = entry
schema = {
"properties": fields,
"_note": "simplified schema (full generation failed)",
}
if field:
props = schema.get("properties", {})
if field not in props:
# Try case-insensitive match
matches = [k for k in props if k.lower() == field.lower()]
if matches:
field = matches[0]
else:
raise click.BadParameter(
f"Unknown field: {field!r}. "
f"Omit --field to dump the full schema, "
f"or pipe to jq: axolotl config-schema | jq '.properties | keys'"
)
schema = {field: props[field]}
if output_format == "yaml":
import yaml # pylint: disable=import-outside-toplevel
click.echo(yaml.dump(schema, default_flow_style=False, sort_keys=False))
else:
click.echo(json.dumps(schema, indent=2))
cli.add_command(lm_eval)

View File

@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
)
LOG.debug("Memory-efficient LoRA merge completed successfully!")

View File

@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _build_layer_type_map(
base_model_path: Path, trust_remote_code: bool = False
) -> dict[str, str]:
"""Build a map of module_name -> layer_type using a meta-device model.
Instantiates the model architecture on the meta device (zero memory)
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
This avoids relying on weight tensor ndim heuristics.
"""
import json as _json
import torch.nn as nn
from transformers import AutoConfig
config_path = base_model_path / "config.json"
if not config_path.exists():
return {}
try:
with open(config_path) as f:
model_config = _json.load(f)
except (OSError, _json.JSONDecodeError):
return {}
architectures = model_config.get("architectures", [])
if not architectures:
return {}
try:
config = AutoConfig.from_pretrained(
str(base_model_path), trust_remote_code=trust_remote_code
)
except Exception:
LOG.debug("Could not load config for layer type introspection")
return {}
# Determine the right Auto class from architectures
from transformers import (
AutoModel,
AutoModelForCausalLM,
)
auto_classes = [AutoModelForCausalLM, AutoModel]
try:
from transformers import AutoModelForImageTextToText
auto_classes.insert(0, AutoModelForImageTextToText)
except ImportError:
pass
model = None
for auto_cls in auto_classes:
try:
with torch.device("meta"):
model = auto_cls.from_config(
config, trust_remote_code=trust_remote_code
)
break
except Exception: # noqa: BLE001
LOG.debug(
"Could not instantiate meta model with %s, trying next",
auto_cls.__name__,
)
if model is None:
LOG.debug("Could not instantiate meta model for layer type introspection")
return {}
layer_types = {}
for name, module in model.named_modules():
if isinstance(module, nn.Conv3d):
layer_types[name] = "Conv3d"
elif isinstance(module, nn.Conv2d):
layer_types[name] = "Conv2d"
elif isinstance(module, nn.Conv1d):
layer_types[name] = "Conv1d"
elif isinstance(module, nn.Linear):
layer_types[name] = "Linear"
del model
LOG.debug(
f"Layer type map: {len(layer_types)} modules "
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
)
return layer_types
def _simulate_nf4_roundtrip(
tensor: torch.Tensor,
blocksize: Optional[int] = None,
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
adapter_name: str = "default",
is_param_wrapper: bool = False,
magnitude: Optional[torch.Tensor] = None,
layer_type: Optional[str] = None,
) -> torch.Tensor:
"""
Use PEFT's own layer classes to compute the LoRA delta weight.
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
out_features = lora_b.shape[0]
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
use_rslora = bool(lora_config_dict.get("use_rslora", False))
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
use_dora = bool(lora_config_dict.get("use_dora", False))
if is_param_wrapper:
from peft.tuners.lora.layer import ParamWrapper
@@ -227,18 +315,110 @@ def _build_peft_layer_and_get_delta(
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
)
# ParamWrapper rejects dropout/fan_in_fan_out/lora_bias/use_dora, so
# build a minimal config with only the fields it accepts.
pw_config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
lora_dropout=0.0,
fan_in_fan_out=False,
use_rslora=use_rslora,
use_dora=False,
lora_bias=False,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
layer = ParamWrapper(
fake,
adapter_name=adapter_name,
parameter_name="weight",
config=pw_config,
r=r,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
delta = layer.get_delta_weight(adapter_name)
# peft >=0.19.1 may return delta with transposed dims for 3D params
if delta.shape != base_tensor.shape and delta.ndim == 3:
delta = delta.transpose(1, 2).contiguous()
return delta
elif (
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
):
# Conv layer detected via model introspection (or ndim fallback)
from peft.tuners.lora import layer as peft_lora_layer
# Determine conv type from layer_type map or fall back to ndim
if layer_type and "Conv" in layer_type:
conv_type: str = layer_type
else:
ndim = lora_a.ndim
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
if ndim not in _conv_map:
raise ValueError(
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
)
conv_type = _conv_map[ndim]
LOG.warning(
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
f"Consider providing layer_type from meta-device introspection."
)
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
ConvCls = conv_cls_map[conv_type]
PeftConvCls = getattr(peft_lora_layer, conv_type)
# Reconstruct conv parameters from base tensor and lora_a shapes
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
# lora_a: [r, in_channels/groups, *kernel_size]
# lora_b: [out_channels, r, *ones]
out_channels = base_tensor.shape[0]
in_channels = base_tensor.shape[1]
kernel_size = tuple(base_tensor.shape[2:])
stride = (1,) * (base_tensor.ndim - 2)
padding = (0,) * (base_tensor.ndim - 2)
base_layer = ConvCls(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
)
base_layer.weight.data = base_tensor.clone()
conv_config = LoraConfig(
r=r_total,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer = PeftConvCls(
base_layer,
adapter_name=adapter_name,
config=conv_config,
r=r_total,
lora_alpha=lora_alpha,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
if use_dora:
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for conv layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
return base_layer.weight.data - base_tensor
return layer.get_delta_weight(adapter_name)
else:
from peft.tuners.lora.layer import Linear as LoraLinear
@@ -251,15 +431,20 @@ def _build_peft_layer_and_get_delta(
or lora_config_dict.get("lora_fan_in_fan_out", False)
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
linear_config = LoraConfig(
r=r_total,
lora_alpha=lora_alpha,
fan_in_fan_out=fan_in_fan_out,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
config=linear_config,
r=r_total,
lora_alpha=lora_alpha,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
@@ -267,6 +452,12 @@ def _build_peft_layer_and_get_delta(
# DoRA merges magnitude normalization into the weight directly.
# Use PEFT's merge() which handles DoRA internally, then
# compute the delta as merged_weight - original_weight.
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for linear layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
@@ -382,6 +573,7 @@ def _merge_tensor_with_lora(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[torch.Tensor, bool]:
"""
Helper function to merge a single tensor with its corresponding LoRA weights.
@@ -426,12 +618,30 @@ def _merge_tensor_with_lora(
if use_dora
else None
)
# Look up layer type from meta-device model introspection
_layer_type = None
if layer_type_map:
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
_layer_type = layer_type_map.get(mod_path)
# Try common prefix variations (e.g. with/without "model." prefix)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
merged_tensor = (
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
@@ -556,6 +766,7 @@ def _fuse_and_unfuse_with_merge(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[Dict[str, torch.Tensor], int, set]:
"""
For tensors matching WeightConverter patterns (MoE expert weights):
@@ -696,12 +907,32 @@ def _fuse_and_unfuse_with_merge(
if use_dora
else None
)
# Look up layer type for the fused key
_layer_type = None
if layer_type_map:
mod_path = (
fused_key.rsplit(".weight", 1)[0]
if fused_key.endswith(".weight")
else fused_key
)
_layer_type = layer_type_map.get(mod_path)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
fused_tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
fused_tensor = (
(
@@ -740,6 +971,7 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
trust_remote_code: bool = False,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
@@ -750,6 +982,8 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
(for quantize_moe_experts). Expert tensors are identified by having
"expert" in the key name and ndim >= 3.
trust_remote_code: Whether to trust remote code when loading model
config for layer-type introspection. Defaults to False for safety.
"""
base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
@@ -780,6 +1014,10 @@ def merge_lora_sharded_efficient(
use_dora = bool(lora_config_dict.get("use_dora", False))
# Build layer type map via meta-device model introspection
layer_type_map = _build_layer_type_map(
base_model_path, trust_remote_code=trust_remote_code
)
unsupported_methods = []
# Check for AdaLoRA (Adaptive LoRA)
@@ -904,6 +1142,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_count += fused_merged
@@ -926,6 +1165,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_tensors[key] = merged_tensor
if was_merged:

View File

@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
SkipEvalOnResumeCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.resume_from_checkpoint:
callbacks.append(SkipEvalOnResumeCallback())
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

View File

@@ -100,6 +100,27 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
# "the model handles num_items_in_batch internally" and skips the loss ÷
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
# (the gradient itself is still correct). Override to False when the model
# doesn't actually consume num_items_in_batch.
if self.model_accepts_loss_kwargs:
model_to_check = self.accelerator.unwrap_model(self.model)
if hasattr(model_to_check, "base_model"): # PEFT wrapper
model_to_check = model_to_check.base_model
if hasattr(model_to_check, "model"):
model_to_check = model_to_check.model
fwd = getattr(model_to_check, "forward", None)
if fwd is not None:
import inspect
params = inspect.signature(fwd).parameters
if "num_items_in_batch" not in params:
self.model_accepts_loss_kwargs = False
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
@@ -383,13 +404,29 @@ class AxolotlTrainer(
# Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator.
# Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
# from position_ids, but only when attention_mask is None. When sample
# packing is active the collator provides an all-ones attention_mask that
# prevents this detection — remove it so the model builds the correct
# per-sequence causal masks.
if (
self.args.sample_packing
and _model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -398,6 +435,23 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
# Gemma4ForConditionalGeneration computes loss with a manual
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
# normalization and does redundant attention_mask filtering.
# Compute loss externally using the standard loss_function instead.
if _model_type == "gemma4" and "labels" in inputs:
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
unwrapped = self.accelerator.unwrap_model(model)
vocab_size = unwrapped.config.get_text_config().vocab_size
loss = unwrapped.loss_function(
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
)
if return_outputs:
return loss, outputs
return loss
return super().compute_loss(
model,
inputs,
@@ -410,6 +464,21 @@ class AxolotlTrainer(
LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs)
@override
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
# Gemma4 requires mm_token_type_ids even during evaluation.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
return super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
```
## Usage
@@ -44,6 +44,7 @@ plugins:
- gemma3_text
- gemma3n
- gemma3n_text
- gemma4
- glm
- glm4
- glm4_moe

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
'`pip uninstall -y cut-cross-entropy && pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
)

View File

@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

View File

@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):

View File

@@ -60,49 +60,14 @@ def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
"""Convert peft LoRA weights to scattermoe layout.
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
peft >=0.19.1 assigns in/out features for 3D params such that
A and B already align with scattermoe's convention (no A<->B swap).
Only B needs rank-major → expert-major layout conversion.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
smoe_A = peft_A
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
return smoe_A, smoe_B

View File

@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from transformers.models.gemma4 import modeling_gemma4
if cfg.liger_rms_norm:
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
class _LigerGemma4RMSNorm(LigerRMSNorm):
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
def __new__(cls, dim, eps=1e-6, with_scale=True):
if not with_scale:
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
return super().__new__(cls)
def __init__(self, dim, eps=1e-6, with_scale=True):
if not with_scale:
return
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
super().__init__(
dim, eps, offset=0.0, casting_mode="llama", in_place=False
)
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
if cfg.liger_glu_activation:
class _LigerGemma4MLP(LigerGEGLUMLP):
def __init__(self, config, layer_idx=None):
super().__init__(config)
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
if cfg.liger_rope:
LOG.warning(
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
)
if cfg.liger_layer_norm:
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
LOG.warning(
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
)
LOG.info(
f"Applied Liger kernels for gemma4: "
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward

View File

@@ -0,0 +1,529 @@
"""
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
Fuses three operations into one kernel launch:
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
3. (optional) RMSNorm without scale (for v_norm)
This eliminates two intermediate tensor materializations per Q/K path;
churn from rotate_half / apply_rotary_pos_emb.
Shapes:
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
sin: (rows, head_dim) — same as cos
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_rope_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
n_heads,
eps,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused forward:
x_norm = x / rms(x) [* weight] (RMSNorm)
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
rotate_half swaps first/second halves and negates the first:
rotate_half([a, b]) = [-b, a]
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
"""
row_idx = tl.program_id(0).to(tl.int64)
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
cs_row_idx = row_idx // n_heads
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
# Load input row
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
# RMSNorm: compute 1/rms
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
# Normalize
X_norm = X_fp32 * rstd
# Apply weight if present (with_scale=True)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
X_norm = X_norm * W_row
# RoPE: load cos/sin (broadcast across heads)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
sin_row = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
# for col >= half_dim, take X_norm[col - half_dim]
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_mask = rot_offsets < n_cols
X_rot = tl.load(
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
).to(tl.float32)
# Re-normalize the rotated values
X_rot_norm = X_rot * rstd
if HAS_WEIGHT:
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
tl.float32
)
X_rot_norm = X_rot_norm * W_rot
# Negate the first half (rotate_half negates x2, which becomes the first half)
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
X_rot_norm = X_rot_norm * sign
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
Y_row = X_norm * cos_row + X_rot_norm * sin_row
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_rope_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
n_heads,
rows_per_program,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = RoPE(RMSNorm(X, W))
cos/sin indexed by row_idx // n_heads for per-head broadcast.
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
for row_idx in range(row_start, row_end):
cs_row_idx = row_idx // n_heads
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# dN = dY * cos + rotate_half^T(dY * sin)
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
#
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
# This is equivalent to rotating (dY * sin) because the rotation
# just permutes which elements are multiplied.
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_mask = rot_offsets < n_cols
dY_rot = tl.load(
dY_ptr + row_idx * dY_row_stride + rot_offsets,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
sin_rot = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
# Pre-weight normalized: n = rstd * x
n = X_row * rstd
if HAS_WEIGHT:
dW_acc += dN * n
dm = dN * W_row
else:
dm = dN
# RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
dot_dm_x = tl.sum(dm * X_row, axis=0)
dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
if HAS_WEIGHT:
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
"""
Args:
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
W: (head_dim,) or None — RMSNorm weight
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
eps: float
n_heads: int — number of attention heads (for cos/sin indexing)
Returns:
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
"""
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
has_weight = W is not None
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_rope_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
W if has_weight else X, # dummy pointer when no weight
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
n_heads,
eps,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y, X, RSTD, BLOCK_SIZE, num_warps
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
n_rows, n_cols = dY.shape
has_weight = W is not None
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
rows_per_program = math.ceil(n_rows / sm_count)
dX = torch.empty_like(X)
if has_weight:
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
else:
_dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
_rms_norm_rope_backward_kernel[(sm_count,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W if has_weight else X, # dummy
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
n_heads,
rows_per_program,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
return dX, dW
class FusedRMSNormRoPEFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, cos, sin, eps, n_heads):
"""
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, head_dim) — broadcast across heads
sin: (B*S, head_dim) — broadcast across heads
n_heads: int
"""
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
X,
W,
cos,
sin,
eps,
n_heads,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_heads = n_heads
ctx.has_weight = W is not None
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, cos, sin, RSTD = ctx.saved_tensors
dX, dW = rms_norm_rope_backward(
dY,
X,
W,
cos,
sin,
RSTD,
ctx.n_heads,
ctx.BLOCK_SIZE,
ctx.num_warps,
)
return dX, dW, None, None, None, None
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
Apply fused RMSNorm + RoPE.
Args:
x: (batch, seq_len, num_heads, head_dim) — after projection + view
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
eps: float — RMSNorm epsilon
Returns:
y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
"""
shape = x.shape # (B, S, H, D)
B, S, H, D = shape
# Flatten to 2D: (B*S*H, D)
x_flat = x.reshape(-1, D).contiguous()
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
# by dividing the row_idx by H to get the cos/sin row
cos_flat = cos.reshape(B * S, D).contiguous()
sin_flat = sin.reshape(B * S, D).contiguous()
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
return y_flat.view(shape)
@triton.jit
def _rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""RMSNorm without scale weight: y = x / rms(x)"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
Y_row = X_fp32 * rstd
tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
@triton.jit
def _rms_norm_noscale_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
RSTD_ptr,
RSTD_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""Backward for y = x * rstd (no weight)."""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
dot_dy_x = tl.sum(dY_row * X_row, axis=0)
dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
)
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
"""RMSNorm without learnable scale — used for Gemma4's v_norm."""
@staticmethod
@ensure_contiguous
def forward(ctx, X, eps):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, RSTD)
ctx.n_cols = n_cols
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, RSTD = ctx.saved_tensors
n_rows = X.shape[0]
dX = torch.empty_like(X)
_rms_norm_noscale_backward_kernel[(n_rows,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
RSTD,
RSTD.stride(0),
ctx.n_cols,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
return dX, None
def fused_rms_norm_noscale(x, eps=1e-6):
"""
RMSNorm without scale for v_norm.
Args:
x: (batch, seq_len, num_heads, head_dim)
Returns:
y: same shape, normalized
"""
shape = x.shape
x_flat = x.reshape(-1, shape[-1])
y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
return y_flat.view(shape)

View File

@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
return Q, K, V
class LoRA_QK(torch.autograd.Function):
"""Optimized LoRA QK implementation for models where v_proj is None.
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
reused as value states. Only Q and K projections are fused; the caller
returns K a second time as V so that autograd accumulates key+value gradients
into a single dK.
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
X_drop: torch.Tensor | None,
# Q params
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
q_lora_bias: torch.Tensor | None,
q_magnitude: torch.Tensor | None,
# K params
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
k_lora_bias: torch.Tensor | None,
k_magnitude: torch.Tensor | None,
# Flags
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
has_dropout = X_drop is not None
has_dora = q_magnitude is not None
if has_dora:
dtype = X.dtype
X_lora = X_drop if has_dropout else X
# Compute Q with DoRA
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
q_mag_scale = _compute_dora_scale(
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
)
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
if q_bias is not None:
Q = Q + q_bias
# Compute K with DoRA
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
k_mag_scale = _compute_dora_scale(
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
)
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
if k_bias is not None:
K = K + k_bias
Q_combined = Q_base + Q_lora
K_combined = K_base + K_lora
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
)
else:
# Standard LoRA (with optional dropout and bias)
Q = matmul_lora(
X,
q_weight,
q_bias,
q_quant,
q_A,
q_B,
q_scale,
X_drop=X_drop,
lora_bias=q_lora_bias,
)
K = matmul_lora(
X,
k_weight,
k_bias,
k_quant,
k_A,
k_B,
k_scale,
X_drop=X_drop,
lora_bias=k_lora_bias,
)
dtype = X.dtype
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_lora_bias,
k_lora_bias,
)
ctx.scales = (q_scale, k_scale)
ctx.quants = (q_quant, k_quant)
ctx.weights = (q_weight, k_weight)
ctx.inplace = inplace
ctx.has_dropout = has_dropout
ctx.has_dora = has_dora
return Q, K
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
):
q_weight, k_weight = ctx.weights
q_quant, k_quant = ctx.quants
q_scale, k_scale = ctx.scales
has_dropout = ctx.has_dropout
has_dora = ctx.has_dora
if has_dora:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
else:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
q_magnitude = k_magnitude = None
q_mag_scale = k_mag_scale = None
Q_combined = K_combined = None
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
X = X.view(-1, X.shape[-1])
X_lora = X_lora.view(-1, X_lora.shape[-1])
d_q_mag = d_k_mag = None
d_q_lora_bias = d_k_lora_bias = None
if has_dora:
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
K_combined = K_combined.view(-1, K_combined.shape[-1])
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
q_grad = q_grad * q_mag_scale.unsqueeze(0)
k_grad = k_grad * k_mag_scale.unsqueeze(0)
# LoRA bias gradients
if q_lora_bias is not None:
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
if k_lora_bias is not None:
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
X_lora_t = X_lora.t()
d_A_q = d_B_q = d_A_k = d_B_k = None
grad_B_q = grad_B_k = None
if A_q is not None and B_q is not None:
grad_B_q = q_grad @ B_q
d_A_q = torch.empty_like(A_q.t())
d_B_q = torch.empty_like(B_q.t())
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
if A_k is not None and B_k is not None:
grad_B_k = k_grad @ B_k
d_A_k = torch.empty_like(A_k.t())
d_B_k = torch.empty_like(B_k.t())
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
# Base path input gradient
out_buffer = X if ctx.inplace else None
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight_t
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight_t
# LoRA path input gradient
if has_dropout:
grad_X_drop = torch.zeros_like(X_lora)
if grad_B_q is not None:
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
else:
grad_X_drop = None
if grad_B_q is not None:
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
grad_X = grad_X.view(batch, seq_len, -1)
if grad_X_drop is not None:
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
# Return gradients for all forward inputs:
# X, X_drop,
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
# inplace
return (
grad_X,
grad_X_drop,
# Q
None,
None,
None,
d_A_q,
d_B_q,
None,
d_q_lora_bias,
d_q_mag,
# K
None,
None,
None,
d_A_k,
d_B_k,
None,
d_k_lora_bias,
d_k_mag,
# inplace
None,
)
def apply_lora_qk(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query and Key projections for models where v_proj is None.
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
Because K is returned twice, autograd accumulates gradients from both the key and
value paths into dK before calling LoRA_QK.backward.
Supports bias, dropout, and DoRA.
"""
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
# Apply dropout outside autograd.Function (shared mask for Q, K)
X_drop = _apply_dropout(Qdrop, X, self.training)
Q, K = LoRA_QK.apply(
X,
X_drop,
# Q
QW,
Qb,
QW_quant,
QA,
QB,
QS,
Qlb,
Qmag,
# K
KW,
Kb,
KW_quant,
KA,
KB,
KS,
Klb,
Kmag,
# Flags
inplace,
)
return Q, K, K
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection.

View File

@@ -67,12 +67,165 @@ def find_all_linear_names(model):
return list(lora_module_names)
def _patch_peft_clippable_linear():
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
that clips activations). PEFT doesn't recognise it as a supported layer type,
so we redirect LoRA injection to the inner ``.linear`` child instead.
"""
try:
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4ClippableLinear as _cls,
)
except ImportError:
return
from peft.tuners.lora.model import LoraModel
if getattr(LoraModel, "_axolotl_clippable_patched", False):
return
_orig = LoraModel._create_and_replace
def _patched(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=None,
**kw,
):
if isinstance(target, _cls):
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
return _orig(
self,
peft_config,
adapter_name,
target.linear,
"linear",
target,
current_key=current_key,
**kw,
)
return _orig(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=current_key,
**kw,
)
LoraModel._create_and_replace = _patched
LoraModel._axolotl_clippable_patched = True
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
"""Check whether PEFT will auto-populate target_parameters for this model.
PEFT 0.19's ``convert_peft_config_for_transformers`` rewrites old MoE
``target_modules`` (e.g. ``w1``/``w2``/``w3`` on Mixtral) into
``target_parameters`` (``gate_up_proj``/``down_proj``) because
transformers v5 fused those expert linears into 3D ``nn.Parameter``
tensors. PEFT wraps the resulting 3D params with ``ParamWrapper``,
which rejects ``lora_dropout != 0``. This probe runs the conversion on
a copy of the config so we can detect the situation before
``get_peft_model`` blows up.
"""
if getattr(lora_config, "target_parameters", None):
return False
try:
from peft.utils.transformers_weight_conversion import (
convert_peft_config_for_transformers,
get_model_conversion_mapping,
)
except ImportError:
return False
import copy
probe_cfg = copy.deepcopy(lora_config)
try:
convert_peft_config_for_transformers(
probe_cfg,
model=model,
conversions=get_model_conversion_mapping(model),
)
except Exception: # pylint: disable=broad-except
return False
return bool(getattr(probe_cfg, "target_parameters", None))
def _patch_peft_param_wrapper_dropout():
"""Let PEFT's ``ParamWrapper`` silently accept ``lora_dropout != 0``.
``ParamWrapper`` wraps 3D expert ``nn.Parameter`` tensors and rejects
non-zero dropout because dropout can't be factored out of
``lora_B(lora_A(dropout(x)))`` when the inner op is an expert-indexed
matmul. For mixed configs (attention + MoE experts) this is too
aggressive — the non-expert ``Linear`` LoRA layers *can* apply dropout
and that's usually what the user intended. We pass a copy of the
``LoraConfig`` with ``lora_dropout=0`` only to ``ParamWrapper.__init__``
so it builds with ``nn.Identity`` for its internal dropout slot while
every other layer type still receives the real dropout value.
"""
from peft.tuners.lora.layer import ParamWrapper
if getattr(ParamWrapper, "_axolotl_dropout_patched", False):
return
_orig_init = ParamWrapper.__init__
def _patched_init(
self,
base_layer,
adapter_name,
parameter_name,
config,
*args,
**kwargs,
):
if getattr(config, "lora_dropout", 0):
import copy as _copy
patched_config = _copy.copy(config)
patched_config.lora_dropout = 0.0
return _orig_init(
self,
base_layer,
adapter_name,
parameter_name,
patched_config,
*args,
**kwargs,
)
return _orig_init(
self,
base_layer,
adapter_name,
parameter_name,
config,
*args,
**kwargs,
)
ParamWrapper.__init__ = _patched_init
ParamWrapper._axolotl_dropout_patched = True
def load_lora(
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
@@ -124,6 +277,7 @@ def load_lora(
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,
@@ -132,6 +286,20 @@ def load_lora(
if config_only:
return None, lora_config
if getattr(
lora_config, "lora_dropout", 0
) and _peft_will_auto_convert_target_params(model, lora_config):
LOG.warning(
"lora_dropout=%s requested but PEFT will wrap this model's fused "
"MoE expert parameters with ParamWrapper, which cannot apply "
"dropout (the 3D einsum can't factor dropout out of "
"lora_B(lora_A(dropout(x)))). Dropout will still be applied to "
"non-expert LoRA layers (e.g. attention), and expert LoRA layers "
"will use nn.Identity for the dropout slot.",
lora_config.lora_dropout,
)
_patch_peft_param_wrapper_dropout()
rank = int(os.environ.get("LOCAL_RANK", 0))
if (

View File

@@ -547,6 +547,16 @@ class ModelLoader:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
if self.cfg.model_quantization_config == "FineGrainedFP8Config":
from transformers import FineGrainedFP8Config
fp8_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
fp8_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
**fp8_kwargs
)
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
LOG.warning(
@@ -624,7 +634,14 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:
if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers;
# global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
# Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True
elif self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"

View File

@@ -156,15 +156,81 @@ class PatchManager:
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
self._apply_gemma_hybrid_attention(model)
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
self._apply_unsloth_patches(model)
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
Gemma 4 has global (full_attention) layers with head_dim=512
which exceeds flash attention's supported size. This patch loads the model
with flash_attention_2 for the sliding window layers (head_dim=256), then
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
"""
if not self.cfg.gemma4_hybrid_attn_impl:
return
import copy
# Navigate to the module that has 'layers' - varies by model structure:
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
layers = None
config_source = None
for candidate in [model, getattr(model, "model", None)]:
if candidate is None:
continue
# Check direct layers
if hasattr(candidate, "layers"):
layers = candidate.layers
config_source = candidate
break
# Check language_model.layers (multimodal wrapper)
lang_model = getattr(candidate, "language_model", None)
if lang_model is not None and hasattr(lang_model, "layers"):
layers = lang_model.layers
config_source = lang_model
break
if layers is None:
LOG.warning(
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
)
return
config = getattr(config_source, "config", self.model_config)
layer_types = getattr(config, "layer_types", None)
if layer_types is None:
LOG.warning(
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
"This feature requires a model with mixed sliding/global attention layers."
)
return
patched_count = 0
for layer_idx, layer in enumerate(layers):
if layer_types[layer_idx] != "sliding_attention":
# Global / full_attention layer - use SDPA instead of FA2
attn_module = getattr(layer, "self_attn", None)
if attn_module is not None and hasattr(attn_module, "config"):
sdpa_config = copy.copy(attn_module.config)
sdpa_config._attn_implementation = "sdpa"
attn_module.config = sdpa_config
patched_count += 1
LOG.info(
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
"(remaining %d sliding layers use flash_attention_2)",
patched_count,
len(layers) - patched_count,
)
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing:
@@ -324,6 +390,22 @@ class PatchManager:
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
# Shared-KV side channel when activation checkpointing (PR #3611).
fsdp_cfg = self.cfg.fsdp_config
needs_shared_kv_workaround = (not self.inference) and bool(
self.cfg.gradient_checkpointing
or self.cfg.activation_offloading
or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
)
patch_gemma4_fused_attn(
install_shared_kv_workaround=needs_shared_kv_workaround
)
@staticmethod
def _fix_nemotron_h_conversion_mapping():
"""Remove the spurious embedding→embeddings WeightRenaming from the
@@ -600,24 +682,10 @@ class PatchManager:
)
patch_fa_llama_cross_entropy()
elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def _patch_llama_flash_attention(self):
"""Apply Flash Attention patches for LLaMA models."""
@@ -684,23 +752,6 @@ class PatchManager:
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)
def _apply_unsloth_patches(self, model):
"""Apply unsloth optimization patches."""
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
integrate_lora_mlp_patch(peft_model=model)
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
integrate_lora_patch(peft_model=model, cfg=self.cfg)
if self.cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
def _apply_lora_kernel_patch(self, model):
"""Apply LoRA kernel patches."""
if (

View File

@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
@@ -303,6 +295,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
{"additional_special_tokens": additional_special_tokens}
)
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
if is_main_process():
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")

View File

@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
sharded_meta_param.placements,
src_data_rank=0,
)
# Clone the local shard to allow full_tensor to be freed.
if (
sharded_param._local_tensor.untyped_storage().size()
> sharded_param._local_tensor.nelement()
* sharded_param._local_tensor.element_size()
):
sharded_param = sharded_param.clone()
else:
# Non-sharded parameters
if _accelerator.is_main_process:

View File

@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
try:
# flash-attn-4>=4.0.0b7
from flash_attn.cute import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
fa_utils._pad_input,
fa_utils._unpad_input,
)

View File

@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qk,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
@@ -111,6 +112,47 @@ QKV_PATCHES = [
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
# past_key_values.shared_layers, and v_norm added after k_norm.
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
]
@@ -483,18 +525,24 @@ def apply_lora_kernel_patches(
if cfg.lora_qkv_kernel:
# Query, key, value patching
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
proj_names = ["q_proj", "k_proj", "v_proj"]
layer_modules = [
getattr(self_attn, name)
for name in proj_names
if getattr(self_attn, name, None) is not None
]
has_v_proj = getattr(self_attn, "v_proj", None) is not None
proj_names = (
["q_proj", "k_proj", "v_proj"]
if has_v_proj
else ["q_proj", "k_proj"]
)
layer_modules = [getattr(self_attn, name) for name in proj_names]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
if has_v_proj:
self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, self_attn
)
else:
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters"

View File

@@ -0,0 +1,194 @@
"""
Gemma 4 fused attention monkeypatch.
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
kernels, eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb
Usage:
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
patch_gemma4_fused_attn(install_shared_kv_workaround=True)
"""
from typing import Callable
import torch
from axolotl.utils.logging import get_logger
logger = get_logger(__name__)
# Module-level dict used as a side channel for shared KV states avoiding kwarg and TLS
# to prevent memory leak on gradient checkpoint enabled training (PR #3611)
_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
def _set_shared_kv_states(store):
_GEMMA4_SHARED_KV_STORE["store"] = store
def _get_shared_kv_states():
return _GEMMA4_SHARED_KV_STORE["store"]
def _make_fused_forward(original_forward):
"""Create a patched forward that uses fused RMSNorm+RoPE kernels."""
from axolotl.kernels.gemma4_fused_rope import (
fused_rms_norm_noscale,
fused_rms_norm_rope,
)
def fused_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
past_key_values=None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gemma4.modeling_gemma4 import (
eager_attention_forward,
)
store = _get_shared_kv_states()
if store is not None:
shared_kv_states = store
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
eps = self.config.rms_norm_eps
cos, sin = position_embeddings
# ---- Projections ----
# Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
has_lora_qkv = hasattr(self, "apply_qkv")
if has_lora_qkv:
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
else:
query_states = self.q_proj(hidden_states).view(hidden_shape)
# ---- Q path: fused q_norm + RoPE ----
query_states = fused_rms_norm_rope(
query_states,
self.q_norm.weight,
cos,
sin,
eps=eps,
)
query_states = query_states.transpose(1, 2)
# ---- K/V path ----
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
if has_lora_qkv:
# apply_qkv already computed k/v projections
key_states = key_states.view(hidden_shape)
value_states = (
value_states.view(hidden_shape)
if self.v_proj is not None
else key_states
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = (
self.v_proj(hidden_states).view(hidden_shape)
if self.v_proj is not None
else key_states
)
# Fused k_norm + RoPE
key_states = fused_rms_norm_rope(
key_states,
self.k_norm.weight,
cos,
sin,
eps=eps,
)
key_states = key_states.transpose(1, 2)
# Fused v_norm (no scale, no RoPE)
value_states = fused_rms_norm_noscale(value_states, eps=eps)
value_states = value_states.transpose(1, 2)
if past_key_values is not None and not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
if self.store_full_length_kv:
shared_kv_states[self.layer_idx] = key_states, value_states
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=self.attention_dropout if self.training else 0.0,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
return fused_forward
def _patch_decoder_layer_call():
"""Strip `shared_kv_states` from decoder-layer kwargs and route via the
module-level side channel so the checkpoint partial cannot pin it (PR #3611).
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
return
original_call = Gemma4TextDecoderLayer.__call__
def patched_call(self, *args, **kwargs):
shared_kv = kwargs.pop("shared_kv_states", None)
# Overwrite unconditionally (including with None) so a previous step's
# dict cannot leak into a later call without shared_kv_states (PR #3611).
_set_shared_kv_states(shared_kv)
return original_call(self, *args, **kwargs)
Gemma4TextDecoderLayer.__call__ = patched_call
Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
"""
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
and optionally route `shared_kv_states` via a module-level side channel to
avoid a VRAM leak under activation checkpointing (PR #3611).
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
original_forward = Gemma4TextAttention.forward
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
if install_shared_kv_workaround:
_patch_decoder_layer_call()
logger.info(
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
)
if install_shared_kv_workaround:
logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")

View File

@@ -1,252 +0,0 @@
"""module for patching with unsloth optimizations"""
import inspect
import types
import torch
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
""".lstrip("\n")
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
""".lstrip("\n")
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip("\n")
PATCHED_O_CODE = """
attn_output = self.apply_o(self, attn_output)
""".lstrip("\n")
def original_apply_qkv(self, hidden_states):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def original_apply_o(self, hidden_states):
attn_output = self.o_proj(hidden_states)
return attn_output
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
def check_self_attn_is_patchable() -> bool:
qkv = get_self_attn_code()
qkv, _ = detab_code(qkv)
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss(
logits,
labels,
vocab_size: int,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
from transformers.loss import loss_utils
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
else:
raise ValueError("Unsupported model type")
self_attn_lora_patched = False
def patch_self_attn_lora():
global self_attn_lora_patched
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def unsloth_attn_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in self_attn_forward:
items_to_import.append(item)
exec(
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(self_attn_forward, globals())
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora")
LlamaFlashAttention2.forward = unsloth_attn_forward
def integrate_rope_embeddings():
import transformers.models.llama.modeling_llama
from unsloth.kernels.rope_embedding import fast_rope_embedding
def apply_rotary_pos_emb(
q,
k,
cos,
sin,
position_ids=None,
unsqueeze_dim=1,
):
return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings")
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
apply_lora_mlp = apply_lora_mlp_swiglu
elif peft_model.base_model.config.model_type == "gemma":
from unsloth.kernels import apply_lora_mlp_geglu_approx
apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(
f"Model type {peft_model.base_model.config.model_type} not supported"
)
for idx, layer in enumerate(peft_model.model.model.layers):
layer_modules = [
getattr(layer.mlp, linear_proj)
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
]
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
mlp_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
mlp_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
from unsloth.kernels import apply_lora_o, apply_lora_qkv
for idx, layer in enumerate(peft_model.model.model.layers):
if cfg.unsloth_lora_qkv:
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
qkv_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
qkv_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
o_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
o_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if is_o_lora and o_no_bias and o_not_dora:
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
def patch_unsloth_layernorm():
try:
import transformers.models.llama.modeling_llama
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
class LlamaRMSNorm(nn.Module):
"""LlamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return Fast_RMS_Layernorm.apply(
hidden_states, self.weight, self.variance_epsilon, False
)
LOG.info("patching with unsloth.kernels.rms_layernorm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning("missing unsloth library")

View File

@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._validate_eot_and_eos_tokens()
# Pre-cache EOT token IDs to avoid re-encoding on every call
self._eot_token_ids = set()
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) == 1:
self._eot_token_ids.add(token_ids[0])
def _validate_eot_and_eos_tokens(self):
"""
- Validates that EOT tokens (or eos_token) are in the chat_template
@@ -471,6 +478,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")
reasoning_train_detail = turn.get("reasoning_training_detail")
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
@@ -479,8 +487,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
elif train_detail is not None or reasoning_train_detail is not None:
should_train = bool(train_detail) or bool(reasoning_train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train
@@ -500,15 +508,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
continue
thinking_key = self.prompter.template_thinking_key
has_reasoning = thinking_key and turn.get(thinking_key) is not None
has_any_detail = train_detail or reasoning_train_detail
# When train_detail is present and the turn has reasoning_content,
# use content_only=True so find_turn returns content-only boundaries
# (excluding reasoning_content + template separator tokens).
use_content_only = bool(has_any_detail and has_reasoning)
turn_start_idx, turn_end_idx = self.find_turn(
turns=turns, turn_idx=index, tools=tools
turns=turns,
turn_idx=index,
tools=tools,
content_only=use_content_only,
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
# Block multi-content for now
if not isinstance(content, str):
raise ValueError(
"`train_detail` is not supported when `content` is not a string."
@@ -526,7 +545,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
elif not reasoning_train_detail:
# No per-part detail on either field — train the whole span
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
@@ -534,6 +554,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
# Handle reasoning_content training_detail separately
if should_train and reasoning_train_detail and has_reasoning:
reasoning_text = turn[thinking_key]
if not isinstance(reasoning_text, str):
raise ValueError(
"`reasoning_training_detail` is not supported when reasoning_content is not a string."
)
reasoning_start, reasoning_end = self.find_turn(
turns=turns,
turn_idx=index,
tools=tools,
reasoning_only=True,
)
if reasoning_start != -1 and reasoning_end != -1:
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
reasoning_text, reasoning_train_detail
)
LOG.debug(f"Reasoning token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
input_ids
):
labels[reasoning_start + i] = input_ids[reasoning_start + i]
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle special tokens (EOT and EOS)
@@ -593,28 +639,31 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# Get token IDs for all EOT tokens
eot_token_ids = []
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) != 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
)
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
# Search for any of the EOT token IDs
# Use pre-cached EOT token IDs (computed once in __init__)
for i in range(start_idx, len(input_ids)):
if input_ids[i] in eot_token_ids:
if input_ids[i] in self._eot_token_ids:
return i
return -1
def find_turn(
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
self,
turns: list[dict],
turn_idx: int,
tools: list[dict] | None = None,
content_only: bool = False,
reasoning_only: bool = False,
):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
content_only: If True and the turn has reasoning_content (template_thinking_key),
preserve reasoning_content in the dummy turn so the diff only captures the
content field boundaries. This is needed for correct training_detail alignment
when reasoning_content is present.
reasoning_only: If True, preserve content in the dummy turn and replace
reasoning_content with a dummy, so the diff only captures the
reasoning_content field boundaries.
"""
if turn_idx >= len(turns):
@@ -628,10 +677,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
):
return -1, -1
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
thinking_key = self.prompter.template_thinking_key
if reasoning_only:
# Keep content as-is, replace reasoning with dummy
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": turns[turn_idx].get("content", ""),
}
if thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = "[[dummy_reasoning]]"
else:
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
# When content_only is True, copy reasoning_content to the dummy turn so
# the diff only captures the content field (not reasoning + separator).
if content_only and thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = turns[turn_idx][thinking_key]
# Create conversation versions
turns_with_empty = turns[:turn_idx] + [empty_turn]
@@ -697,6 +762,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return start_idx, end_idx
@staticmethod
def _convert_content_parts(
content,
) -> tuple[str, list[dict] | None] | None:
"""Convert list content to concatenated string + optional training_detail.
When content is a list of dicts (content parts), each part can specify:
- ``text``, ``content``, or ``value``: the text string
- ``train`` (bool) or ``weight`` (0/1): per-part training flag
Returns ``(concatenated_text, training_details_or_None)`` if content was
a list, or ``None`` if content was not a list (no conversion needed).
.. note::
**Whitespace at part boundaries matters.** BPE tokenizers prepend
spaces to word tokens (e.g. ``" answer"`` is one token). Always
split BEFORE spaces::
GOOD: ["Let me think...", " The answer is 4."]
BAD: ["Let me think... ", "The answer is 4."]
Tokens that straddle a boundary are conservatively masked.
Newlines typically merge with preceding punctuation (``":\\n"`` is
one token), so keep newlines with the preceding part.
"""
if not isinstance(content, list):
return None
text_parts: list[str] = []
training_details: list[dict] = []
has_explicit_training = False
offset = 0
for part in content:
if isinstance(part, dict):
# Extract text (HF uses "text", also support "content"/"value")
text = (
part.get("text") or part.get("content") or part.get("value") or ""
)
text_parts.append(text)
# Check for per-part training flags
part_train = part.get("train")
part_weight = part.get("weight")
if part_train is not None or part_weight is not None:
has_explicit_training = True
train = (
part_train
if part_train is not None
else (part_weight not in (0, 0.0))
)
else:
train = True # default trainable, gated by turn-level should_train
if text:
training_details.append(
{
"begin_offset": offset,
"end_offset": offset + len(text) - 1,
"train": train,
}
)
offset += len(text)
# Warn about trailing whitespace at boundaries between parts with
# different training flags — this almost always causes token straddling
if has_explicit_training and len(training_details) > 1:
for i in range(len(training_details) - 1):
cur = training_details[i]
nxt = training_details[i + 1]
if cur["train"] != nxt["train"]:
boundary_text = text_parts[i]
if boundary_text and boundary_text[-1] in (" ", "\t"):
LOG.warning(
"Content part %d ends with whitespace at a train/mask boundary. "
"BPE tokenizers typically prepend spaces to word tokens, so "
"the space will merge with the next part's first word and the "
"resulting token will be MASKED (not trained). Move the "
"whitespace to the start of the next content part instead. "
"Part text: %r",
i,
boundary_text[-20:],
)
concatenated = "".join(text_parts)
details = training_details if has_explicit_training else None
return concatenated, details
def get_conversation_thread(self, prompt):
turns = []
@@ -723,6 +876,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if training_detail is not None:
turn["training_detail"] = training_detail
# Convert list content/reasoning_content to string + auto-generated
# training_detail. See _convert_content_parts for whitespace guidance.
content_result = self._convert_content_parts(turn.get("content"))
if content_result is not None:
turn["content"] = content_result[0]
if content_result[1] is not None:
turn["training_detail"] = content_result[1]
# Also convert reasoning_content (template_thinking_key) if it's a list
thinking_key = self.prompter.template_thinking_key
if thinking_key and thinking_key in turn:
reasoning_result = self._convert_content_parts(turn[thinking_key])
if reasoning_result is not None:
turn[thinking_key] = reasoning_result[0]
if reasoning_result[1] is not None:
turn["reasoning_training_detail"] = reasoning_result[1]
turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":

View File

@@ -160,29 +160,16 @@ class TelemetryManager:
if not is_main_process():
return False
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
def is_truthy_env(var_name: str) -> bool:
value = os.getenv(var_name)
if value is None:
return False
return value.strip().lower() in ("1", "true")
# Default to enabled (opt-out model)
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
"0",
"1",
"false",
"true",
):
return True
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
return enabled
# Telemetry is enabled by default unless either opt-out var is set
return not (
is_truthy_env("AXOLOTL_DO_NOT_TRACK") or is_truthy_env("DO_NOT_TRACK")
)
def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist"""

View File

@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
):
model.enable_input_require_grads()
# Freeze multimodal modules for text-only training of multimodal models
if cfg.freeze_mm_modules:
freeze_mm_modules(model)
return model, tokenizer, peft_config, processor
@@ -225,6 +229,28 @@ def execute_training(
PLUGIN_MANAGER.post_train(cfg, trainer.model)
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
"""Rename model*.safetensors files to adapter_model* in place.
Also rewrites the index JSON weight_map if sharded output was produced.
"""
for file in sorted(merged_dir.iterdir()):
if file.name.startswith("model") and ".safetensors" in file.name:
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
index = merged_dir / "adapter_model.safetensors.index.json"
if index.exists():
data = json.loads(index.read_text(encoding="utf-8"))
if "weight_map" in data:
data["weight_map"] = {
k: v.replace("model", "adapter_model", 1)
for k, v in data["weight_map"].items()
}
index.write_text(
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
)
def save_trained_model(
cfg: DictDefault,
trainer: Any,
@@ -294,12 +320,17 @@ def save_trained_model(
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
# FSDP checkpoints for PEFT only contain adapter weights;
# rename model* → adapter_model* so it loads correctly.
is_peft = cfg.adapter and not cfg.relora
if is_peft:
_rename_fsdp_merged_to_adapter(Path(merged_path))
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
dest = Path(cfg.output_dir) / merged_file.name
if dest.exists():
dest.unlink()
shutil.move(str(merged_file), dest)
shutil.rmtree(merged_path)
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
if trainer.accelerator.is_main_process:

View File

@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
return control
class SkipEvalOnResumeCallback(TrainerCallback):
"""Skip the redundant evaluation that fires when resuming from a checkpoint
whose step aligns with ``eval_steps``.
When HuggingFace Trainer resumes, it restores ``global_step`` from the
checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
step. Because the evaluation was already performed during the original
run, repeating it wastes time and pollutes metric logs.
This callback records the ``global_step`` at the start of training (i.e.
the checkpoint step when resuming, or 0 for a fresh run) and suppresses
any evaluation request on that exact step.
"""
def __init__(self):
super().__init__()
self._resume_step: int | None = None
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
# ``global_step`` is already restored from the checkpoint at this
# point. For a fresh run it will be 0, so the guard below becomes a
# no-op.
self._resume_step = state.global_step
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
if (
self._resume_step
and state.global_step <= self._resume_step
and control.should_evaluate
):
LOG.info(
"Skipping evaluation at step %d (already completed before resume)",
state.global_step,
)
control.should_evaluate = False
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -1,7 +1,19 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- if messages[0].content is string %}
{{- messages[0].content + '\n\n' }}
{%- else %}
{%- for part in messages[0].content %}
{%- if part is mapping %}
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
{%- if system_text %}{{- system_text }}{%- endif %}
{%- elif part is string %}
{{- part }}
{%- endif %}
{%- endfor %}
{{- '\n\n' }}
{%- endif %}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
@@ -11,7 +23,20 @@
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- if messages[0].content is string %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>system\n' }}
{%- for part in messages[0].content %}
{%- if part is mapping %}
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
{%- if system_text %}{{- system_text }}{%- endif %}
{%- elif part is string %}
{{- part }}
{%- endif %}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}

View File

@@ -268,6 +268,37 @@ def normalize_config(cfg):
):
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
# Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
# "marked ready twice" errors with reentrant checkpointing) and
# ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
# receive gradients on every step).
if cfg.model_config_type == "gemma4":
if cfg.gradient_checkpointing:
if cfg.gradient_checkpointing_kwargs is None:
cfg.gradient_checkpointing_kwargs = {}
if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
LOG.warning(
"Gemma4 requires use_reentrant=False for gradient checkpointing "
"in distributed training. Setting use_reentrant=False."
)
cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
if cfg.ddp and cfg.ddp_find_unused_parameters is None:
if cfg.activation_offloading is True:
# activation_offloading uses checkpoint wrappers that conflict
# with find_unused_parameters (causes "marked ready twice").
# Use freeze_mm_modules instead to eliminate unused params.
LOG.info(
"Gemma4 + DDP + activation_offloading: skipping "
"ddp_find_unused_parameters (use freeze_mm_modules to "
"handle unused vision/audio params)."
)
else:
LOG.warning(
"Gemma4 requires ddp_find_unused_parameters=True for DDP. "
"Auto-enabling."
)
cfg.ddp_find_unused_parameters = True
log_gpu_memory_usage(LOG, "baseline", cfg.device)

View File

@@ -180,6 +180,119 @@ def _drop_long_sequences(
raise ValueError("Unknown RL type")
def _raise_on_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
"""Check sequence length and raise ValueError if exceeded.
Used as a filter function for ``excess_length_strategy: raise``.
Args:
sample: Dataset sample to check.
rl: Reinforcement learning type.
tokenizer: Tokenizer for length calculation.
sequence_len: Maximum allowed sequence length.
Returns:
Always True (raises before returning False).
Raises:
ValueError: If any sample exceeds the configured sequence length.
"""
is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
if not is_valid:
raise ValueError(
f"Sample exceeds configured sequence_len ({sequence_len}). "
"Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
"to handle long sequences automatically."
)
return True
def _truncate_long_sequences_rl(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> dict[str, Any]:
"""Truncate RL samples that exceed maximum sequence length.
For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
responses to fit within ``sequence_len`` when combined with the prompt.
For KTO, truncates the completion similarly.
GRPO/GDPO/EBFT samples are returned unchanged.
Samples where the prompt alone exceeds ``sequence_len`` cannot be
meaningfully truncated and are returned unchanged. The caller should
follow up with a drop filter to remove them.
Args:
sample: Dataset sample to potentially truncate.
rl: Reinforcement learning type.
tokenizer: Tokenizer for encoding/decoding.
sequence_len: Maximum allowed sequence length.
Returns:
The sample with text fields truncated to fit within sequence_len.
"""
# Fast path: if sample already fits, return unchanged (avoids decode overhead)
if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
return sample
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
raise ValueError(
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
)
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
"input_ids"
]
max_response_len = sequence_len - len(prompt_ids)
if max_response_len <= 0:
# Prompt alone exceeds limit; cannot meaningfully truncate.
# Returned unchanged — the follow-up drop filter will remove it.
return sample
updates: dict[str, Any] = {}
if len(chosen_ids) > max_response_len:
updates["chosen"] = tokenizer.decode(
chosen_ids[:max_response_len], skip_special_tokens=False
)
if len(rejected_ids) > max_response_len:
updates["rejected"] = tokenizer.decode(
rejected_ids[:max_response_len], skip_special_tokens=False
)
if updates:
sample = {**sample, **updates}
elif rl is RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
"input_ids"
]
max_completion_len = sequence_len - len(prompt_ids)
if max_completion_len <= 0:
return sample
if len(completion_ids) > max_completion_len:
sample = {
**sample,
"completion": tokenizer.decode(
completion_ids[:max_completion_len], skip_special_tokens=False
),
}
# GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
return sample
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
split_datasets[i] = dataset
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
if excess_length_strategy == "truncate":
truncate_fn = partial(
_truncate_long_sequences_rl,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].map(
truncate_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# Drop samples that could not be truncated (e.g. prompt
# alone exceeds sequence_len)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Un-truncatable Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} samples from dataset index {i} "
f"that could not be truncated to fit sequence_len "
f"(prompt alone exceeds limit)"
)
elif excess_length_strategy == "raise":
raise_fn = partial(
_raise_on_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
raise_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Checking Sequence Lengths",
)
else: # "drop" (default)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)

View File

@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Top-level module name prefixes that belong to vision/audio/multimodal encoders
# rather than the language backbone. These are matched against the first component
# of each ``named_parameter`` path (e.g. "model.vision_tower." -> "vision_tower").
_MM_MODULE_PREFIXES = (
"vision_tower",
"vision_model",
"vision_encoder",
"embed_vision",
"multi_modal_projector",
"visual",
"audio_tower",
"audio_model",
"embed_audio",
)
def freeze_mm_modules(model):
"""Freeze all vision/audio/multimodal-projector parameters.
Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
for any parameter whose name contains a known vision/audio module prefix.
This is useful when fine-tuning only the language backbone of a multimodal
model and avoids the need for ``ddp_find_unused_parameters=True``.
"""
frozen_count = 0
for name, param in model.named_parameters():
# Check if any path component matches a vision/audio prefix
parts = name.split(".")
if any(part in _MM_MODULE_PREFIXES for part in parts):
if param.requires_grad:
param.requires_grad = False
frozen_count += 1
if is_main_process():
LOG.debug(f"freeze_mm_modules: froze {name}")
if is_main_process():
LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
def freeze_layers_except(model, regex_patterns):
"""

View File

@@ -578,6 +578,17 @@ class AxolotlInputConfig(
},
)
freeze_mm_modules: bool | None = Field(
default=None,
json_schema_extra={
"description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
"text-only training of multimodal models. When True, parameters belonging to "
"vision towers, audio towers, multimodal projectors, and similar non-language "
"modules are frozen (requires_grad=False). This allows DDP training without "
"ddp_find_unused_parameters=True."
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
@@ -766,6 +777,15 @@ class AxolotlInputConfig(
},
)
gemma4_hybrid_attn_impl: bool | None = Field(
default=None,
json_schema_extra={
"description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
"and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
"exceeds flash attention's supported size."
},
)
experts_implementation: str | None = Field(
default=None,
json_schema_extra={
@@ -803,13 +823,6 @@ class AxolotlInputConfig(
},
)
unsloth_cross_entropy_loss: bool | None = None
unsloth_lora_mlp: bool | None = None
unsloth_lora_qkv: bool | None = None
unsloth_lora_o: bool | None = None
unsloth_rms_norm: bool | None = None
unsloth_rope: bool | None = None
lora_mlp_kernel: bool | None = Field(
default=None,
json_schema_extra={
@@ -1449,21 +1462,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
@@ -1517,8 +1515,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
# RL trainers not tested so don't enable kernels by default
return data
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
# Skip if already set or using 8-bit
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
@@ -1527,7 +1524,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
or data.get("adapter") == "lora"
and data.get("load_in_8bit")
):

Some files were not shown because too many files have changed in this diff Show More