Compare commits

..

38 Commits

Author SHA1 Message Date
Wing Lian
d17ed89a3c add missing file 2026-04-21 08:44:01 -04:00
Wing Lian
02e4f2350d fixes for scattermoe from latest peft upgrade 2026-04-21 08:00:16 -04:00
Wing Lian
4195605ab2 fix test dims 2026-04-21 00:44:26 +00:00
Wing Lian
37acb28d02 fix einsum dims 2026-04-20 23:09:47 +00:00
Wing Lian
4a5281e61a Fix shape 2026-04-19 01:53:05 +00:00
Wing Lian
a892d8cce1 chore: lint 2026-04-17 17:48:26 +00:00
Wing Lian
78de2919a6 tiled mlp fix for gemma4 2026-04-16 13:24:41 +00:00
Wing Lian
28283ff373 revert shared_kv_states workaround with transformers 5.5.4 2026-04-15 13:32:59 +00:00
Wing Lian
dc16859983 [gemma4] fix fused RMSNorm+RoPE on hybrid attention models
- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1].
  Triton forward/backward take an n_rot runtime arg that restricts
  rotate_half to [0, n_rot) and treats trailing cols as RMSNorm-only
  pass-through (cos=1, sin=0 defaults). Wrapper also expands cos/sin
  that broadcast over batch.

- Forward: _make_fused_forward used a stale shared_kv_states kwarg the
  current decoder layer no longer passes. Now mirrors stock attention,
  reading/writing past_key_values.shared_layers.
2026-04-15 13:27:31 +00:00
Wing Lian
d4e9cf2eec lint 2026-04-15 13:27:30 +00:00
Wing Lian
53391a10d7 vllm-serve-lora add /v1/completions route + worker pipe lock
The LoRA vllm-serve wrapper only exposed /v1/chat/completions, but
retrace's SWE agent server uses the token-id-aware /v1/completions
endpoint so it can feed raw prompt_token_ids + track per-token
logprobs across multi-turn rollouts. Add the route, mirroring the
shape of /v1/chat/completions but routing to the vLLM worker's
generate() method so prompt_token_ids are passed through as-is.

Also add a worker_pipe_lock around conn.send/conn.recv. The
multiprocessing.Connection to the vLLM worker is a single shared
full-duplex pipe; concurrent HTTP requests interleave pickle frames
on the wire and corrupt the stream (observed as
UnpicklingError: pickle data was truncated, surfacing as 500s).
The agent server fires ~8 concurrent rollout requests at once, so
this was a hard blocker for any multi-concurrent workload. Serialize
access to the pipe per-request round-trip.
2026-04-15 13:27:30 +00:00
Wing Lian
7617b951a8 make _maybe_sync_vllm_weights actually fire in sync mode
Two bugs in ``AsyncGRPOTrainer._maybe_sync_vllm_weights`` plus a
companion bug in the sync-hook patch site that together neutralized
LoRA weight sync entirely whenever ``async_prefetch=False`` was
combined with NeMo Gym's data-producer path:

1. ``_maybe_sync_vllm_weights`` had ``if not async_prefetch: return``
   at the top. The original design assumed sync mode would fall back
   to TRL's stock per-step ``sync_weights`` call inside
   ``_generate_single_turn`` — true for vanilla GRPO but FALSE in
   NeMo Gym multi-turn, where ``NemoGymDataProducer`` calls the agent
   server directly and ``_generate_single_turn`` is never invoked.
   Result: no sync ever happened in NeMo Gym sync mode.

2. ``step % vllm_sync_interval`` would TypeError on the first call if
   ``vllm_sync_interval`` was unset (the default for any config that
   doesn't explicitly set it).

3. The ``_generate_single_turn`` patch installed
   ``vllm_generation.sync_weights = lambda: None`` unconditionally
   for vllm_lora_sync runs. That's correct in async-prefetch mode
   (BG thread can't safely sync) but wrong in sync mode: TRL's
   per-step auto-sync inside ``_generate_single_turn`` was the
   fallback that the early return in (1) was assuming, and the
   no-op patch was killing it.

Fix:
  - Drop the ``not async_prefetch`` early return; ``_maybe_sync_vllm_weights``
    is now the canonical sync trigger and runs in both modes from
    ``_prepare_inputs_with_data_producer`` / ``_prepare_inputs_legacy_async``.
  - Default ``vllm_sync_interval`` to 1 when unset.
  - In the ``_generate_single_turn`` patch, route sync_weights to
    ``_sync_lora_adapter`` in sync mode (and keep the lambda no-op
    in async mode for the BG-thread safety reason).
2026-04-15 13:27:30 +00:00
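Bug (2) and its fix reduce to a small guard; a minimal sketch with hypothetical names (the real check lives inside `_maybe_sync_vllm_weights`):

```python
def should_sync_weights(step, vllm_sync_interval=None):
    # `step % None` raises TypeError, so default the interval to 1
    # (sync every step) when the config leaves it unset.
    interval = vllm_sync_interval if vllm_sync_interval else 1
    return step % interval == 0
```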
Wing Lian
e993ed5208 retry head-server probe with longer timeout
``get_server_configs`` was hardcoded to a 5s timeout with no retry.
That's empirically too tight to survive a kill-and-relaunch cycle:
when the agent server is finishing in-flight rollouts from a prior
run, it can take 10-30s to respond to /global_config_dict_yaml, and
the trainer would crash at startup with a ReadTimeoutError.

Bump the per-attempt timeout to 30s and retry up to 3 times with a
2s/4s backoff. The retry intentionally raises a RuntimeError after
the third failure rather than returning empty config — silent
failure here would let training proceed with no agent servers
discovered, which is also a no-op trainer.
2026-04-15 13:27:30 +00:00
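The retry policy described above can be sketched as below. `probe_with_retry` and its parameters are hypothetical names for illustration; the real code wraps the HTTP call in `get_server_configs`.

```python
import time

def probe_with_retry(fetch, attempts=3, timeout=30.0, backoffs=(2.0, 4.0)):
    # Three attempts, 30s per-attempt timeout, 2s/4s sleeps between
    # failures. Raising after the last attempt is deliberate: returning
    # an empty config would let training start with zero agent servers
    # discovered, i.e. a silent no-op trainer.
    last_exc = None
    for attempt in range(attempts):
        try:
            return fetch(timeout=timeout)
        except Exception as exc:  # e.g. a ReadTimeout in the real code
            last_exc = exc
            if attempt < attempts - 1:
                time.sleep(backoffs[attempt])
    raise RuntimeError("head-server probe failed after retries") from last_exc
```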
Wing Lian
69f165b39b probe vLLM weight-sync routes and select transport per server
The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).

Replace both patches with:

1. A probe of the configured vLLM server's /openapi.json at
   pre_model_load. Three transports are recognized:
     - NCCL (/init_communicator/ + /update_named_param/) — TRL serve
       and axolotl vllm-serve both expose this
     - LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
     - HTTP base64 full-weight (/http_update_weights/) — axolotl
       vllm-serve only

2. A pure-logic ``select_weight_sync_transport`` that picks the
   right one for (server caps × adapter type).

3. ``init_communicator`` is only patched out when the server has no
   NCCL routes; against TRL/axolotl serve modules it stays live so
   full-finetune NCCL sync works.

4. ``post_trainer_create`` uses the selection table to install LoRA
   filesystem sync OR leave the standard NCCL flow alone OR raise
   NotImplementedError (HTTP — pending) OR raise a precise diagnosis
   when no transport is viable. No more silent no-op trainers.
2026-04-15 13:27:30 +00:00
Wing Lian
80a97f192b validate batch shape against num_generations at config time
Surfaces a class of GRPO config errors at axolotl-train startup instead
of letting them bubble out of GRPOTrainer.__init__ after the model loads.
Three checks under RLValidationMixin.check_grpo_batch_size_divisibility:

  - effective generation_batch_size (or mb*GA fallback) must be divisible
    by trl.num_generations, with a hint pointing at the smallest GA bump
    that fixes the violation
  - num_generations >= 2 (group-relative advantage needs variance; with
    num_gen=1 the policy never updates)
  - When world_size > 1, effective gbs >= num_generations * world_size

11 unit tests cover the table: divisible/non-divisible, explicit and
implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and
unset num_generations.
2026-04-15 13:27:30 +00:00
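The three checks reduce to a few lines; an illustrative sketch (the real code is `RLValidationMixin.check_grpo_batch_size_divisibility` and reads these values from the axolotl config — the function and argument names here are hypothetical):

```python
def check_grpo_batch_shape(generation_batch_size, num_generations, world_size=1):
    if num_generations < 2:
        raise ValueError(
            "num_generations must be >= 2: group-relative advantage needs "
            "within-group variance, so with num_generations=1 the policy never updates"
        )
    if generation_batch_size % num_generations != 0:
        # Smallest bump that restores divisibility, for the error hint.
        shortfall = -generation_batch_size % num_generations
        raise ValueError(
            f"generation_batch_size={generation_batch_size} is not divisible by "
            f"num_generations={num_generations}; the smallest fix adds {shortfall}"
        )
    if world_size > 1 and generation_batch_size < num_generations * world_size:
        raise ValueError(
            "effective generation_batch_size must be >= num_generations * world_size"
        )
```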
Wing Lian
323da791eb bump transformers to 5.5.4 and trl to latest 1.1.0 (#3603)
* bump transformers to 5.5.4 and trl to latest 1.1.0

* more upgrades

* update peft too

* adapt lora_merge to peft 0.19 layer config API

PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.

* allow lora_dropout on mixed attention+MoE configs under peft 0.19

PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).

Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
2026-04-15 09:27:03 -04:00
NanoCode012
6990478163 fix: rename model to adapter_model for fsdp sharded final model (#3585)
* fix: rename model to adapter_model for fsdp sharded final model

* fix: follow upstream transformer shard size

* fix: handle multiple model files

* fix redundant condition, tighten to safetensors, keep shard size small

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:30 -04:00
ゆり
63a58cfec1 feat: support excess_length_strategy for RL trainers (#3578) [skip ci]
* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
  within sequence_len while preserving the full prompt, then decodes
  back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:10 -04:00
madScientist10
3985ec2f67 feat: add FineGrainedFP8Config support for model quantization (#3587) [skip ci]
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.

Made-with: Cursor
2026-04-12 20:50:37 -04:00
Joaquin Hui
a44edda6d7 Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:50:15 -04:00
Wing Lian
66c3e5a3fd better handling of dora merge on Conv layers in Qwen 3.5 (#3599)
* better handling of dora merge on Conv layers in Qwen 3.5

* address issues from code review

* stricter efficient merges for dora since we now have meta model to reference
2026-04-12 10:57:45 -04:00
Wing Lian
b8358aa5ab [gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598) 2026-04-12 10:29:55 -04:00
Joaquin Hui
e079cf16a2 qwen3_5.jinja: handle list content on system messages (#3595) [skip ci]
* qwen3_5.jinja: handle list content on system messages

The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.

Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.

Fixes #3590

* Address pr for other content types

---------

Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 00:58:58 -04:00
Wing Lian
e2f69828d2 [fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci] 2026-04-11 20:22:35 -04:00
Wing Lian
122b50bad6 pre-cache the eot token ids rather than on each iteration (#3594) [skip ci] 2026-04-11 20:05:21 -04:00
Wing Lian
e77a185e86 upgrade transformers to use v5.5.3 (#3593) 2026-04-10 17:08:14 -04:00
Wing Lian
29fa4dedbb Gemma4 fixes and profiler (#3591) 2026-04-10 16:46:17 -04:00
Wing Lian
315cdeede9 handle trainable/masked spans in content and reasoning content (#3592) 2026-04-10 14:11:10 -04:00
NanoCode012
e7a6a5b529 fix: move warning after we've set any overrides (#3589) [skip ci] 2026-04-10 13:00:47 -04:00
NanoCode012
bfb4da1d25 fix: document jinja2 file path support (#3588) [skip ci] 2026-04-10 13:00:26 -04:00
floaty3
4dfa0a59b2 Add uninstall command to cut_cross_entropy import message (#3583) [skip ci] 2026-04-10 13:00:07 -04:00
Wing Lian
4ef608dda3 fix ddp/fsdp w gemma4 (#3584)
* fix ddp/fsdp w gemma4

* address pr comments

* activation offloading fix and update agent docs for gemma4
2026-04-09 20:02:36 -07:00
NanoCode012
7daf7d96f1 fix: regex for unfrozen language tower (#3586) [skip ci]
* fix: regex for unfrozen language tower

* fix: other leftover regex
2026-04-08 08:18:11 -07:00
Wing Lian
7c56809c7f use vllm 0.19.0 for torch 2.10.0 (#3582) 2026-04-07 08:09:49 -07:00
NanoCode012
149178ddb7 chore: cleanup post release v0.16 (#3577)
* fix: remove unneeded debug log

* fix: cleanup

* feat: add dense gemma config and cleanup

* feat: add cce support

* update notes and set torch compile

* fix patch for new number of return vals

* fixes for gemma4

* fix packing bug

* use updated cce for mm

* fix: pass in kv cache func when avail for transformers 5.5

* feat: update examples with flex variant and readme

* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-06 10:10:52 -07:00
NanoCode012
dc638e723f fix(config): add cce and liger to nemotron-h example (#3573) [skip ci] 2026-04-06 10:10:25 -07:00
Wing Lian
6f15da4cac make it easier for agents to discover docs (#3579) [skip ci]
* make it easier for agents to discover docs

* fixup pr comments
2026-04-06 10:00:55 -07:00
Maxime
900eec7988 Fix DO_NOT_TRACK not being correctly handled (#3580)
* Fix DO_NOT_TRACK not being correctly handled

* add unit tests and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-04 05:16:58 -04:00
82 changed files with 8714 additions and 461 deletions

View File

@@ -220,6 +220,16 @@ jobs:
        run: |
          axolotl --help
      - name: Verify agent docs are discoverable
        run: |
          # Agent docs live in docs/agents/ (source of truth) and are resolved
          # at runtime from the repo checkout or via `axolotl fetch docs`
          axolotl agent-docs --list
          axolotl agent-docs | grep -q "Fine-tuning framework"
          axolotl agent-docs grpo | grep -q "GRPO"
          axolotl agent-docs sft | grep -q "SFT"
          python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"
      - name: Show HF cache
        run: hf cache ls

View File

@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
axolotl fetch examples # Download example configs
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
axolotl agent-docs grpo # Topic-specific agent reference
axolotl config-schema # Dump config JSON schema
```
## Training Methods
@@ -35,6 +38,8 @@ Agent-specific references:
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
## Config Pattern

View File

@@ -3,4 +3,6 @@ include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md
recursive-include axolotl *.py

View File

@@ -86,7 +86,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1
### Google Colab
@@ -95,6 +95,34 @@ Features:
### Installation
#### Using uv (recommended)
```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
#### Using pip
```bash
@@ -157,6 +185,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
## AI Agent Support
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
```bash
# Show overview and available training methods
axolotl agent-docs
# Topic-specific references
axolotl agent-docs sft # supervised fine-tuning
axolotl agent-docs grpo # GRPO online RL
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
axolotl agent-docs reward_modelling # outcome and process reward models
axolotl agent-docs pretraining # continual pretraining
axolotl agent-docs --list # list all topics
# Dump config schema for programmatic use
axolotl config-schema
axolotl config-schema --field adapter
```
If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.
## 🤝 Getting Help
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support

View File

@@ -0,0 +1,198 @@
# Model Architectures — Agent Reference
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## VLM (Vision Language Model) Quick Start
All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```
## Plugins & Optimizations
### Cut Cross Entropy (CCE)
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
### ScatterMoE Kernels
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
## Gemma 4
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
  use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```
### Auto-detection
Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
### Multi-GPU
| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
FSDP2 config:
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_version: 2
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```
### MoE (26B-A4B)
- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
### VLM (Vision) Training
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
### Common issues
| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
## Gemma 3
**Models**: `google/gemma-3-*`
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
## Qwen 3.5 MoE
**Models**: `Qwen/Qwen3.5-35B-A3B`
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
```
## General MoE Notes
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
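The "rare experts get a larger effective learning rate" point can be made concrete with the Adam update magnitude, lr * g / (sqrt(v) + eps). The numbers below are illustrative, not measured from any model:

```python
import math

def adam_step_size(grad, v, lr=1e-4, eps=1e-8):
    # Per-parameter Adam update magnitude for a single step, ignoring
    # momentum and bias correction for clarity.
    return lr * grad / (math.sqrt(v) + eps)

# A "hot" expert sees gradients often, so its second-moment estimate v
# tracks the gradient scale; a rare expert's v has decayed toward zero
# while the expert sat idle, so the same gradient produces a far larger
# update -- the weight-drift mechanism described above.
hot = adam_step_size(grad=1e-3, v=1e-6)    # v ~ grad**2
rare = adam_step_size(grad=1e-3, v=1e-12)  # v decayed during idleness
```

Here `rare` comes out hundreds of times larger than `hot`, which is why `normalize_weight_scales` with `dry_run: true` is worth running on MoE checkpoints.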

View File

@@ -0,0 +1,181 @@
# New Model Support — Agent Reference
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
## Quick Validation Checklist
When testing a new model, run through these checks in order:
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5-2.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
## Loss Debugging
### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.5-2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
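The random-guessing threshold is just the entropy of a uniform distribution over the vocabulary, which is quick to verify:

```python
import math

# Uniform guessing over the vocabulary gives cross-entropy log(vocab_size);
# a loss pinned near this value means predictions carry no signal at all.
vocab_size = 262_144  # the 262K vocab mentioned above
random_loss = math.log(vocab_size)  # ~12.48
```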
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:
```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader
cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()
# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
```
If direct loss is correct (~1.0) but the trainer reports 3-4x higher, check `model_accepts_loss_kwargs` (see below).
### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
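The signature check described in the fix can be sketched with `inspect`; `accepts_loss_kwargs` is a hypothetical helper, and the dummy classes below stand in for unwrapped models:

```python
import inspect

def accepts_loss_kwargs(model):
    # Trust the model's real forward signature rather than the mere
    # presence of **kwargs: only set the flag when num_items_in_batch
    # is an explicit parameter the loss computation actually uses.
    params = inspect.signature(model.forward).parameters
    return "num_items_in_batch" in params
```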
## Multimodal Models (ForConditionalGeneration)
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`
### Why this matters
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |
### Required extra inputs
Multimodal models require special inputs during training even for text-only data:
| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:
| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
## Sample Packing
### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
packed_sequence_mask = find_packed_sequence_indices(position_ids)
```
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
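A toy pure-Python version of the reset detection makes the mechanism concrete (the real `find_packed_sequence_indices` operates on tensors; this sketch is for intuition only):

```python
def packed_sequence_starts(position_ids):
    # A new packed sequence begins wherever position_ids resets, i.e.
    # the value is not the previous position plus one.
    starts = [0]
    for i in range(1, len(position_ids)):
        if position_ids[i] != position_ids[i - 1] + 1:
            starts.append(i)
    return starts

# Three sequences of lengths 4, 3, 2 packed into one row:
# position_ids = [0,1,2,3, 0,1,2, 0,1] -> sequences start at 0, 4, 7
```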
### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
```python
# In compute_loss():
if (
self.args.sample_packing
and model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
```
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
## Attention Backend Selection
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
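The constraint reduces to a head_dim check, which can be sketched as follows (not axolotl's actual dispatch; the constant and function name are illustrative):

```python
FA2_MAX_HEAD_DIM = 256  # flash-attn CUDA kernel limit, not a PyTorch limit

def pick_attention_backend(head_dim, supports_flash_attn=True):
    if supports_flash_attn and head_dim <= FA2_MAX_HEAD_DIM:
        return "flash_attention_2"
    return "sdpa"  # universal fallback; handles arbitrary head_dim

print(pick_attention_backend(128))  # flash_attention_2
print(pick_attention_backend(512))  # sdpa (e.g. Gemma4 global_head_dim)
```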
## Cut Cross Entropy (CCE)
### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
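Back-of-envelope with illustrative numbers (batch 4, seq_len 2048, a 256k vocab, bf16) shows the size of the logits tensor CCE avoids materializing:

```python
batch, seq_len, vocab_size, dtype_bytes = 4, 2048, 256_000, 2  # bf16 = 2 bytes
logits_bytes = batch * seq_len * vocab_size * dtype_bytes
print(f"{logits_bytes / 2**30:.2f} GiB")  # 3.91 GiB of peak VRAM avoided
```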
### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
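A quick sanity check (a hedged sketch; it assumes the patched function is literally named `cce_forward`, as in axolotl's generic fallback):

```python
def cce_patch_is_active(model) -> bool:
    # The patch replaced some class's forward with `cce_forward`; check that
    # the forward the *loaded* model actually resolves is the patched one.
    # If CCE patched ForCausalLM but the model loaded as
    # ForConditionalGeneration, this returns False.
    return getattr(type(model).forward, "__name__", "") == "cce_forward"
```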
## MoE Models
### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`, giving a significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
## Where to Add Model-Specific Fixes
| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |


@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
## Profiling
To profile training and identify optimization opportunities:
```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.
To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```
The trace shows per-kernel CUDA times, memory allocations, and an operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
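To pull the top offenders out of the trace programmatically, a small pass over the JSON works (a sketch assuming standard Chrome trace event fields, not the bundled `analyze_profile.py`):

```python
import json
from collections import Counter

def top_kernels(trace_path, n=5):
    """Sum self-time per CUDA kernel name from a Chrome trace file.
    "ph" == "X" marks complete events; "dur" is in microseconds."""
    with open(trace_path) as f:
        events = json.load(f).get("traceEvents", [])
    totals = Counter()
    for ev in events:
        if ev.get("ph") == "X" and ev.get("cat") == "kernel":
            totals[ev.get("name", "?")] += ev.get("dur", 0)
    return totals.most_common(n)
```

Summing per kernel name surfaces both the large-matmul candidates and the small-frequent-kernel fusion candidates.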
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
## File Map


@@ -108,6 +108,14 @@ datasets:
type: chat_template
```
::: {.callout-tip}
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
```yaml
chat_template_jinja: ./path/to/my_template.jinja2
```
:::
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens:`.
:::
@@ -294,6 +302,113 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Content parts with per-part training control
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think step by step...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
}
]
}
```
The configuration is the same as standard `chat_template` — no extra fields needed:
```yaml
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
Each content part supports:
- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
**Split BEFORE spaces** (space goes with the next part):
```json
[
{"type": "text", "text": "Let me think...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
```
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
```json
[
{"type": "text", "text": "Let me think... ", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
```json
[
{"type": "text", "text": "Thinking:\n", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::
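The conservative masking rule can be sketched with character spans (illustrative, not axolotl's implementation): a token trains only if every content part it overlaps is flagged trainable.

```python
def token_train_flags(token_spans, part_spans):
    """token_spans: [(start, end)] char offsets per token.
    part_spans: [(start, end, train)] char offsets per content part."""
    flags = []
    for tok_start, tok_end in token_spans:
        overlapping = [
            train for p_start, p_end, train in part_spans
            if p_start < tok_end and tok_start < p_end
        ]
        # A token straddling a train=False part is conservatively masked.
        flags.append(bool(overlapping) and all(overlapping))
    return flags

# Good split: " The" (chars 15-19) sits entirely inside the train=True part.
good = [(0, 15, False), (15, 32, True)]
# Bad split: a trailing space moves the boundary to 16, so " The" straddles it.
bad = [(0, 16, False), (16, 32, True)]
print(token_train_flags([(15, 19)], good))  # [True]
print(token_train_flags([(15, 19)], bad))   # [False]
```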
::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::
##### Per-part training on reasoning_content
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": true}
]
}
]
}
```
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
#### Reasoning split
(For the Qwen3 template only) Enables `split_thinking`, which splits the reasoning from the content and passes it as a separate field into the template.


@@ -8,6 +8,7 @@ format:
## Supported Models
- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-4 {#sec-gemma-4}
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
chat_template: gemma4
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
:::
::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}


@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
},
{


@@ -26,8 +26,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_r: 32


@@ -26,8 +26,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_r: 32


@@ -22,8 +22,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_model_dir:


@@ -1,19 +1,12 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
# Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
# LoRA to the text backbone via lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
# via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
# (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
# 4096 seq_len OOMs due to head_dim=512 math SDP materializing full score matrix.
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
base_model: google/gemma-4-26B-A4B
@@ -24,7 +17,7 @@ plugins:
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
@@ -54,12 +47,9 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using regex to match only the text decoder attention projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:
@@ -73,7 +63,7 @@ lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project: gemma4-qlora
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
@@ -93,8 +83,7 @@ gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
# FA2 not supported
sdp_attention: true
warmup_ratio: 0.1


@@ -0,0 +1,71 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora-flex
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
flex_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:


@@ -0,0 +1,69 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using regex to match only the text decoder attention projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/gemma4/README.md Normal file

@@ -0,0 +1,60 @@
# Finetune Google's Gemma 4 with Axolotl
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
axolotl train examples/gemma4/31b-qlora.yaml
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
axolotl train examples/gemma4/31b-qlora-flex.yaml
```
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
## Flex Attention
Reduce VRAM by ~40% (at the cost of up to half the throughput) by setting the following (shown in `examples/gemma4/31b-qlora-flex.yaml`):
```yaml
torch_compile: true
flex_attention: true
```
This works for both the MoE and Dense model.
## Limitations
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
- **LoRA kernels**: Not supported due to KV-sharing layers.
- **lora_target_linear**: Incompatible with multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
### TIPS
- Read more about loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy and has not been tested.
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -0,0 +1,62 @@
# Gemma 4 E2B Vision LoRA
#
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
base_model: google/gemma-4-E2B-it
processor_type: AutoProcessor
freeze_mm_modules: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: gemma4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/gemma4-e2b-vision-lora
adapter: lora
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Target language model only — vision encoder is frozen via freeze_mm_modules
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


@@ -1,5 +1,15 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
@@ -22,8 +32,6 @@ dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
@@ -31,16 +39,16 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
# Attention projection layers (present in ~12 attention layers out of 88)
- q_proj
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:


@@ -1,6 +1,16 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
@@ -23,8 +33,6 @@ dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
@@ -36,11 +44,12 @@ lora_target_modules:
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:


@@ -26,8 +26,8 @@ sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
- model.language_model.*
- lm_head.*
wandb_project:
wandb_entity:


@@ -0,0 +1,62 @@
# Qwen 3.5 35B-A3B MoE Vision LoRA
#
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
# 256 experts, 8 active per token, with early-fusion vision support.
base_model: Qwen/Qwen3.5-35B-A3B
processor_type: AutoProcessor
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
adapter: lora
sequence_len: 4096
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


@@ -10,15 +10,15 @@ liger-kernel==0.7.0
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
peft>=0.19.0,<0.20.0
tokenizers>=0.22.1
transformers==5.5.0
transformers==5.5.4
accelerate==1.13.0
datasets==4.5.0
datasets>=4.8.4,<4.9.0
deepspeed>=0.18.6,<0.19.0
trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
trl==1.1.0
hf_xet==1.4.3
kernels==0.13.0
fla-core==0.4.1
flash-linear-attention==0.4.1

scripts/analyze_profile.py Normal file

File diff suppressed because it is too large


@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
)


@@ -89,7 +89,7 @@ def parse_requirements(extras_require_map):
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm>=0.17.1"]
extras_require_map["vllm"] = ["vllm>=0.19.0"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [


@@ -0,0 +1,108 @@
"""Bundled agent documentation for axolotl.
These docs are optimized for consumption by AI coding agents.
The source of truth is docs/agents/*.md and AGENTS.md in the repo root.
This module resolves those paths at runtime — no files are duplicated
into the package.
For pip-only installs (no repo checkout), run `axolotl fetch docs` first
to download the docs locally.
"""
from pathlib import Path
# Topic name -> (filename in docs/agents/, fallback filename for AGENTS.md)
TOPICS = {
"overview": "AGENTS.md",
"sft": "docs/agents/sft.md",
"grpo": "docs/agents/grpo.md",
"preference_tuning": "docs/agents/preference_tuning.md",
"reward_modelling": "docs/agents/reward_modelling.md",
"pretraining": "docs/agents/pretraining.md",
"model_architectures": "docs/agents/model_architectures.md",
"new_model_support": "docs/agents/new_model_support.md",
}
def _find_repo_root() -> Path | None:
"""Walk up from this file to find the repo root (contains AGENTS.md)."""
# In an editable install or repo checkout, walk up from
# src/axolotl/cli/agent_docs/ to find the repo root
current = Path(__file__).resolve().parent
while current != current.parent:
if (current / "AGENTS.md").exists() and (current / "docs" / "agents").is_dir():
return current
current = current.parent
return None
def _find_docs_dir() -> Path | None:
"""Find a fetched docs directory (from `axolotl fetch docs`)."""
# axolotl fetch docs --dest defaults to ./docs/ in cwd
cwd_docs = Path.cwd() / "docs" / "agents"
if cwd_docs.is_dir():
return Path.cwd()
return None
def _resolve_path(topic: str) -> Path:
"""Resolve a topic name to the actual file path."""
if topic not in TOPICS:
available = ", ".join(sorted(TOPICS.keys()))
raise FileNotFoundError(f"Unknown topic: {topic!r}. Available: {available}")
relative_path = TOPICS[topic]
# Try repo root first (editable install / repo checkout)
repo_root = _find_repo_root()
if repo_root:
candidate = repo_root / relative_path
if candidate.exists():
return candidate
# Try cwd (fetched docs via `axolotl fetch docs`)
docs_root = _find_docs_dir()
if docs_root:
candidate = docs_root / relative_path
if candidate.exists():
return candidate
# Also check cwd directly for AGENTS.md
if topic == "overview":
cwd_agents = Path.cwd() / "AGENTS.md"
if cwd_agents.exists():
return cwd_agents
raise FileNotFoundError(
f"Could not find {relative_path!r}. "
f"If you installed axolotl via pip, run `axolotl fetch docs` first "
f"to download the documentation."
)
def get_doc(topic: str = "overview") -> str:
"""Return the content of an agent doc by topic name.
Args:
topic: One of the keys in TOPICS, or "overview" (default).
Returns:
The markdown content of the doc.
Raises:
FileNotFoundError: If the topic can't be found.
"""
return _resolve_path(topic).read_text()
def list_topics() -> dict[str, str]:
"""Return a dict of topic name -> first line (title) of each doc."""
result = {}
for topic in sorted(TOPICS.keys()):
try:
path = _resolve_path(topic)
first_line = path.read_text().split("\n", 1)[0].lstrip("# ").strip()
result[topic] = first_line
except FileNotFoundError:
result[topic] = "(not found — run `axolotl fetch docs`)"
return result


@@ -294,7 +294,9 @@ def merge_lora(config: str, **kwargs):
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.argument(
"directory", type=click.Choice(["examples", "deepspeed_configs", "docs"])
)
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
@@ -303,9 +305,10 @@ def fetch(directory: str, dest: Optional[str]):
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
- docs: Full documentation (Quarto markdown files)
Args:
directory: One of `examples`, `deepspeed_configs`.
directory: One of `examples`, `deepspeed_configs`, `docs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)
@@ -340,6 +343,112 @@ def delinearize_llama4(model: str, output: str):
do_delinearize_llama4(model, output)
@cli.command("agent-docs")
@click.argument("topic", required=False, default=None)
@click.option("--list", "list_topics", is_flag=True, help="List available topics")
def agent_docs(topic: Optional[str], list_topics: bool):
"""Show agent-optimized documentation.
Prints reference docs designed for AI coding agents.
These docs are bundled with the package — no network access needed.
\b
Examples:
axolotl agent-docs # overview (start here)
axolotl agent-docs grpo # GRPO reference
axolotl agent-docs sft # SFT reference
axolotl agent-docs --list # list all topics
"""
from axolotl.cli.agent_docs import get_doc, list_topics as _list_topics
if list_topics:
for name, title in _list_topics().items():
click.echo(f" {name:25s} {title}")
return
if topic is None:
topic = "overview"
try:
click.echo(get_doc(topic))
except FileNotFoundError as exc:
raise click.BadParameter(str(exc)) from exc
@cli.command("config-schema")
@click.option(
"--format",
"output_format",
type=click.Choice(["json", "yaml"]),
default="json",
help="Output format (default: json)",
)
@click.option("--field", help="Show schema for a specific field only")
def config_schema(output_format: str, field: Optional[str]):
"""Dump the full config JSON schema.
Useful for AI agents and tooling to discover all available config options,
their types, defaults, and descriptions.
\b
Examples:
axolotl config-schema # full JSON schema
axolotl config-schema --format yaml # YAML format
axolotl config-schema --field adapter # single field
"""
import json
try:
schema = AxolotlInputConfig.model_json_schema()
except (TypeError, ValueError, AttributeError) as exc:
# Fallback: dump field names, types, and defaults when full schema
# generation fails (e.g. torch.dtype not JSON-serializable)
LOG.warning(
"Full JSON schema generation failed, using simplified fallback: %s", exc
)
fields = {}
for name, field_info in AxolotlInputConfig.model_fields.items():
entry = {}
if field_info.description:
entry["description"] = field_info.description
if field_info.default is not None:
try:
json.dumps(field_info.default)
entry["default"] = field_info.default
except (TypeError, ValueError):
entry["default"] = str(field_info.default)
annotation = field_info.annotation
if annotation is not None:
entry["type"] = str(annotation)
fields[name] = entry
schema = {
"properties": fields,
"_note": "simplified schema (full generation failed)",
}
if field:
props = schema.get("properties", {})
if field not in props:
# Try case-insensitive match
matches = [k for k in props if k.lower() == field.lower()]
if matches:
field = matches[0]
else:
raise click.BadParameter(
f"Unknown field: {field!r}. "
f"Omit --field to dump the full schema, "
f"or pipe to jq: axolotl config-schema | jq '.properties | keys'"
)
schema = {field: props[field]}
if output_format == "yaml":
import yaml # pylint: disable=import-outside-toplevel
click.echo(yaml.dump(schema, default_flow_style=False, sort_keys=False))
else:
click.echo(json.dumps(schema, indent=2))
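The `json.dumps` probe in the fallback branch (keep a default that serializes cleanly, otherwise stringify it) can be factored as a small helper. A minimal sketch; `safe_default` is an illustrative name, not part of the CLI:

```python
import json

def safe_default(value):
    # Keep the value as-is when json can serialize it, else fall
    # back to its string representation (e.g. torch.dtype, sets).
    try:
        json.dumps(value)
        return value
    except (TypeError, ValueError):
        return str(value)

print(safe_default(42))         # 42 (JSON-serializable, kept)
print(safe_default({1, 2, 3}))  # set is not JSON-serializable, stringified
```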
cli.add_command(lm_eval)

View File

@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
)
LOG.debug("Memory-efficient LoRA merge completed successfully!")

View File

@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _build_layer_type_map(
base_model_path: Path, trust_remote_code: bool = False
) -> dict[str, str]:
"""Build a map of module_name -> layer_type using a meta-device model.
Instantiates the model architecture on the meta device (zero memory)
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
This avoids relying on weight tensor ndim heuristics.
"""
import json as _json
import torch.nn as nn
from transformers import AutoConfig
config_path = base_model_path / "config.json"
if not config_path.exists():
return {}
try:
with open(config_path) as f:
model_config = _json.load(f)
except (OSError, _json.JSONDecodeError):
return {}
architectures = model_config.get("architectures", [])
if not architectures:
return {}
try:
config = AutoConfig.from_pretrained(
str(base_model_path), trust_remote_code=trust_remote_code
)
except Exception:
LOG.debug("Could not load config for layer type introspection")
return {}
# Determine the right Auto class from architectures
from transformers import (
AutoModel,
AutoModelForCausalLM,
)
auto_classes = [AutoModelForCausalLM, AutoModel]
try:
from transformers import AutoModelForImageTextToText
auto_classes.insert(0, AutoModelForImageTextToText)
except ImportError:
pass
model = None
for auto_cls in auto_classes:
try:
with torch.device("meta"):
model = auto_cls.from_config(
config, trust_remote_code=trust_remote_code
)
break
except Exception: # noqa: BLE001
LOG.debug(
"Could not instantiate meta model with %s, trying next",
auto_cls.__name__,
)
if model is None:
LOG.debug("Could not instantiate meta model for layer type introspection")
return {}
layer_types = {}
for name, module in model.named_modules():
if isinstance(module, nn.Conv3d):
layer_types[name] = "Conv3d"
elif isinstance(module, nn.Conv2d):
layer_types[name] = "Conv2d"
elif isinstance(module, nn.Conv1d):
layer_types[name] = "Conv1d"
elif isinstance(module, nn.Linear):
layer_types[name] = "Linear"
del model
LOG.debug(
f"Layer type map: {len(layer_types)} modules "
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
)
return layer_types
def _simulate_nf4_roundtrip(
tensor: torch.Tensor,
blocksize: Optional[int] = None,
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
adapter_name: str = "default",
is_param_wrapper: bool = False,
magnitude: Optional[torch.Tensor] = None,
layer_type: Optional[str] = None,
) -> torch.Tensor:
"""
Use PEFT's own layer classes to compute the LoRA delta weight.
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
out_features = lora_b.shape[0]
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
use_rslora = bool(lora_config_dict.get("use_rslora", False))
-use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
+use_dora = bool(lora_config_dict.get("use_dora", False))
if is_param_wrapper:
from peft.tuners.lora.layer import ParamWrapper
@@ -227,18 +315,106 @@ def _build_peft_layer_and_get_delta(
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
)
# ParamWrapper rejects dropout/fan_in_fan_out/lora_bias/use_dora, so
# build a minimal config with only the fields it accepts.
pw_config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
lora_dropout=0.0,
fan_in_fan_out=False,
use_rslora=use_rslora,
use_dora=False,
lora_bias=False,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
layer = ParamWrapper(
fake,
adapter_name=adapter_name,
parameter_name="weight",
config=pw_config,
r=r,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
return layer.get_delta_weight(adapter_name)
elif (
(layer_type and "Conv" in layer_type) or (layer_type is None and lora_a.ndim > 2)
):
# Conv layer detected via model introspection (or ndim fallback)
from peft.tuners.lora import layer as peft_lora_layer
# Determine conv type from layer_type map or fall back to ndim
if layer_type and "Conv" in layer_type:
conv_type: str = layer_type
else:
ndim = lora_a.ndim
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
if ndim not in _conv_map:
raise ValueError(
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
)
conv_type = _conv_map[ndim]
LOG.warning(
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
f"Consider providing layer_type from meta-device introspection."
)
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
ConvCls = conv_cls_map[conv_type]
PeftConvCls = getattr(peft_lora_layer, conv_type)
# Reconstruct conv parameters from base tensor and lora_a shapes
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
# lora_a: [r, in_channels/groups, *kernel_size]
# lora_b: [out_channels, r, *ones]
out_channels = base_tensor.shape[0]
in_channels = base_tensor.shape[1]
kernel_size = tuple(base_tensor.shape[2:])
stride = (1,) * (base_tensor.ndim - 2)
padding = (0,) * (base_tensor.ndim - 2)
base_layer = ConvCls(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
)
base_layer.weight.data = base_tensor.clone()
conv_config = LoraConfig(
r=r_total,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer = PeftConvCls(
base_layer,
adapter_name=adapter_name,
config=conv_config,
r=r_total,
lora_alpha=lora_alpha,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
if use_dora:
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for conv layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
return base_layer.weight.data - base_tensor
return layer.get_delta_weight(adapter_name)
else:
from peft.tuners.lora.layer import Linear as LoraLinear
@@ -251,15 +427,20 @@ def _build_peft_layer_and_get_delta(
or lora_config_dict.get("lora_fan_in_fan_out", False)
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
linear_config = LoraConfig(
r=r_total,
lora_alpha=lora_alpha,
fan_in_fan_out=fan_in_fan_out,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
config=linear_config,
r=r_total,
lora_alpha=lora_alpha,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
@@ -267,6 +448,12 @@ def _build_peft_layer_and_get_delta(
# DoRA merges magnitude normalization into the weight directly.
# Use PEFT's merge() which handles DoRA internally, then
# compute the delta as merged_weight - original_weight.
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for linear layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
@@ -382,6 +569,7 @@ def _merge_tensor_with_lora(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[torch.Tensor, bool]:
"""
Helper function to merge a single tensor with its corresponding LoRA weights.
@@ -426,12 +614,30 @@ def _merge_tensor_with_lora(
if use_dora
else None
)
# Look up layer type from meta-device model introspection
_layer_type = None
if layer_type_map:
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
_layer_type = layer_type_map.get(mod_path)
# Try common prefix variations (e.g. with/without "model." prefix)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
merged_tensor = (
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
@@ -556,6 +762,7 @@ def _fuse_and_unfuse_with_merge(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[Dict[str, torch.Tensor], int, set]:
"""
For tensors matching WeightConverter patterns (MoE expert weights):
@@ -696,12 +903,32 @@ def _fuse_and_unfuse_with_merge(
if use_dora
else None
)
# Look up layer type for the fused key
_layer_type = None
if layer_type_map:
mod_path = (
fused_key.rsplit(".weight", 1)[0]
if fused_key.endswith(".weight")
else fused_key
)
_layer_type = layer_type_map.get(mod_path)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
fused_tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
fused_tensor = (
(
@@ -740,6 +967,7 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
trust_remote_code: bool = False,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
@@ -750,6 +978,8 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
(for quantize_moe_experts). Expert tensors are identified by having
"expert" in the key name and ndim >= 3.
trust_remote_code: Whether to trust remote code when loading model
config for layer-type introspection. Defaults to False for safety.
"""
base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
@@ -780,6 +1010,10 @@ def merge_lora_sharded_efficient(
use_dora = bool(lora_config_dict.get("use_dora", False))
# Build layer type map via meta-device model introspection
layer_type_map = _build_layer_type_map(
base_model_path, trust_remote_code=trust_remote_code
)
unsupported_methods = []
# Check for AdaLoRA (Adaptive LoRA)
@@ -904,6 +1138,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_count += fused_merged
@@ -926,6 +1161,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_tensors[key] = merged_tensor
if was_merged:

View File

@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
SkipEvalOnResumeCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.resume_from_checkpoint:
callbacks.append(SkipEvalOnResumeCallback())
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

View File

@@ -100,6 +100,27 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
# "the model handles num_items_in_batch internally" and skips the loss ÷
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
# (the gradient itself is still correct). Override to False when the model
# doesn't actually consume num_items_in_batch.
if self.model_accepts_loss_kwargs:
model_to_check = self.accelerator.unwrap_model(self.model)
if hasattr(model_to_check, "base_model"): # PEFT wrapper
model_to_check = model_to_check.base_model
if hasattr(model_to_check, "model"):
model_to_check = model_to_check.model
fwd = getattr(model_to_check, "forward", None)
if fwd is not None:
import inspect
params = inspect.signature(fwd).parameters
if "num_items_in_batch" not in params:
self.model_accepts_loss_kwargs = False
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
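The signature probe above can be illustrated in isolation. The two forward functions below are hypothetical stand-ins for a model's `forward`: one declares only `**kwargs` (which HF Trainer misreads as "handles loss kwargs"), the other explicitly accepts `num_items_in_batch`:

```python
import inspect

def fwd_kwargs_only(input_ids, **kwargs):
    # Gemma4-style forward: extra inputs arrive via **kwargs, but
    # num_items_in_batch is never actually consumed for the loss.
    ...

def fwd_explicit(input_ids, num_items_in_batch=None):
    # Forward that genuinely consumes num_items_in_batch.
    ...

def consumes_loss_kwarg(fwd):
    # Same check as above: only an explicit parameter counts,
    # a bare **kwargs does not.
    return "num_items_in_batch" in inspect.signature(fwd).parameters

print(consumes_loss_kwarg(fwd_kwargs_only))  # False -> disable the flag
print(consumes_loss_kwarg(fwd_explicit))     # True
```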
@@ -383,13 +404,29 @@ class AxolotlTrainer(
# Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator.
# Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
-and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
+and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
# from position_ids, but only when attention_mask is None. When sample
# packing is active the collator provides an all-ones attention_mask that
# prevents this detection — remove it so the model builds the correct
# per-sequence causal masks.
if (
self.args.sample_packing
and _model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -398,6 +435,23 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
# Gemma4ForConditionalGeneration computes loss with a manual
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
# normalization and does redundant attention_mask filtering.
# Compute loss externally using the standard loss_function instead.
if _model_type == "gemma4" and "labels" in inputs:
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
unwrapped = self.accelerator.unwrap_model(model)
vocab_size = unwrapped.config.get_text_config().vocab_size
loss = unwrapped.loss_function(
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
)
if return_outputs:
return loss, outputs
return loss
return super().compute_loss(
model,
inputs,
@@ -410,6 +464,21 @@ class AxolotlTrainer(
LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs)
@override
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
# Gemma4 requires mm_token_type_ids even during evaluation.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
return super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -242,6 +242,85 @@ class ProducerConfig:
)
class _GroupShardedSampler:
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
``RepeatSampler`` yields ``num_generations`` consecutive copies of
each prompt, forming a GRPO group. For distributed training each
rank must see a disjoint slice of prompts (otherwise every rank
dogpiles on the first 1/world_size of the batch) while keeping each
group intact on a single rank so advantage normalization sees all
peer generations.
``accelerator.prepare(DataLoader)`` does not handle this correctly
for custom samplers with ``split_batches=False`` (the default): it
leaves the sampler alone and every rank replays identical indices.
This wrapper fixes that by consuming the inner sampler's full
output, chunking it into ``num_generations``-sized groups, and
round-robining whole groups across ranks.
Intended to be used ONLY when distributed training is active
(``num_replicas > 1``); for single-rank it is a no-op but still
correct.
"""
def __init__(
self,
inner: Any,
num_generations: int,
rank: int,
num_replicas: int,
):
if num_generations < 1:
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
if num_replicas < 1:
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
if not (0 <= rank < num_replicas):
raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
self.inner = inner
self.num_generations = num_generations
self.rank = rank
self.num_replicas = num_replicas
def __iter__(self):
all_indices = list(self.inner)
if len(all_indices) % self.num_generations != 0:
raise ValueError(
f"inner sampler yielded {len(all_indices)} indices, "
f"not a multiple of num_generations={self.num_generations}"
)
# Chunk the flat index sequence into groups of num_generations
# consecutive indices. ``RepeatSampler`` guarantees that each
# group contains num_generations copies of the same prompt id.
groups = [
all_indices[i : i + self.num_generations]
for i in range(0, len(all_indices), self.num_generations)
]
# Round-robin whole groups across ranks. Round-robin (vs.
# contiguous chunking) preserves approximate shuffled order on
# each rank even when the group count is small relative to the
# world size.
for group in groups[self.rank :: self.num_replicas]:
yield from group
def __len__(self):
try:
inner_len = len(self.inner)
except TypeError:
# Non-sized inner sampler — we can't know the per-rank
# length without materializing. Return 0 as a hint that the
# DataLoader should fall back to iteration.
return 0
total_groups = inner_len // self.num_generations
# Ceiling division for the trailing groups that don't divide
# evenly — extra groups go to the first ``total_groups %
# num_replicas`` ranks, matching the round-robin above.
my_groups = (
total_groups + self.num_replicas - self.rank - 1
) // self.num_replicas
return my_groups * self.num_generations
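The group-preserving round-robin and the matching `__len__` arithmetic can be sketched standalone. Plain Python; `shard_groups` and `shard_len` are illustrative names, not part of the codebase:

```python
def shard_groups(indices, num_generations, rank, num_replicas):
    # Chunk the flat index stream into num_generations-sized groups,
    # then give each rank every num_replicas-th whole group starting
    # at its own rank, so groups never straddle ranks.
    assert len(indices) % num_generations == 0
    groups = [
        indices[i : i + num_generations]
        for i in range(0, len(indices), num_generations)
    ]
    out = []
    for group in groups[rank::num_replicas]:
        out.extend(group)
    return out

def shard_len(inner_len, num_generations, rank, num_replicas):
    # Mirrors __len__: ceiling-divide the groups among ranks so the
    # first (total_groups % num_replicas) ranks get one extra group.
    total_groups = inner_len // num_generations
    my_groups = (total_groups + num_replicas - rank - 1) // num_replicas
    return my_groups * num_generations

# 3 prompts x 2 generations -> groups [0,1], [2,3], [4,5]
indices = [0, 1, 2, 3, 4, 5]
rank0 = shard_groups(indices, 2, rank=0, num_replicas=2)  # [0, 1, 4, 5]
rank1 = shard_groups(indices, 2, rank=1, num_replicas=2)  # [2, 3]
```

Note how rank 0 receives two groups and rank 1 one group, exactly what the ceiling division in `shard_len` predicts.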
class DataProducer(ABC):
"""Abstract base class for online data producers.
@@ -556,6 +635,34 @@ class GRPODataProducer(BaseDataProducer):
seed=self._seed,
)
# Shard the sampler across distributed ranks so each rank sees
# a disjoint slice of prompts. ``RepeatSampler`` groups each
# prompt with ``num_generations`` consecutive copies — our
# wrapper round-robins WHOLE groups across ranks so all
# generations of a given prompt stay on the same rank (needed
# for GRPO advantage normalization within a group).
#
# Without this, ``accelerator.prepare(dl)`` with the default
# ``split_batches=False`` leaves the custom sampler alone, so
# every rank iterates the identical index sequence and the
# cluster dogpiles on the first 1/world_size of the prompts.
num_replicas = max(1, trainer.accelerator.num_processes)
if num_replicas > 1:
sampler = _GroupShardedSampler(
inner=sampler,
num_generations=self._num_generations,
rank=trainer.accelerator.process_index,
num_replicas=num_replicas,
)
logger.info(
"[RANK:%d] _GroupShardedSampler active "
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
trainer.accelerator.process_index,
num_replicas,
self._num_generations,
self._generation_batch_size,
)
# Use identity collator (same as stock GRPOTrainer)
def _identity(x):
return x
@@ -574,12 +681,11 @@ class GRPODataProducer(BaseDataProducer):
rank=trainer.args.process_index,
),
)
-self._prompt_dl = trainer.accelerator.prepare(dl)
-# Don't let accelerator track this dataloader
-acc_dls = trainer.accelerator._dataloaders
-if self._prompt_dl in acc_dls:
-acc_dls.remove(self._prompt_dl)
+# Skip accelerator.prepare — we're handling per-rank sharding
+# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
+# otherwise try to wrap the DataLoader with its own sharding
+# logic which does not understand our group structure.
+self._prompt_dl = dl
self._prompt_iter = iter(self._prompt_dl)
@@ -1103,11 +1209,22 @@ class AsyncGRPOTrainer(GRPOTrainer):
- vllm_lora_sync: saves adapter to filesystem, vLLM loads natively
- PEFT no-merge: computes merged weights as new tensors, NCCL broadcast
- Non-PEFT: stock sync_weights via merge_adapter + NCCL
This is the canonical sync trigger and runs in BOTH async and
synchronous modes from ``_prepare_inputs_with_data_producer`` /
``_prepare_inputs_legacy_async``. The ``_generate_single_turn``
patch is a parallel backup for non-data-producer paths (vanilla
GRPO without NeMo Gym), where the data producer is bypassed
entirely and TRL's stock generate-then-sync flow is used instead.
"""
-if not (self.use_vllm and self.args.async_prefetch):
+if not self.use_vllm:
return
step = self.state.global_step
-interval = self.args.vllm_sync_interval
+# Default to syncing every step when no interval is configured —
+# otherwise ``step % None`` would raise a TypeError, and the previous
+# behavior of crashing on the first sync was strictly worse than
+# the standard "sync every optimizer step".
+interval = self.args.vllm_sync_interval or 1
if step != self._last_synced_step and step % interval == 0:
if step == 0:
logger.info("Skipping vLLM weight sync at step 0 (no training yet)")
@@ -1202,13 +1319,42 @@ class AsyncGRPOTrainer(GRPOTrainer):
# Permanently replace vllm_generation.sync_weights with our custom
# sync to avoid merge_adapter (fails on FP8 / races with training).
# For LoRA sync mode, make it a no-op here since _maybe_sync_vllm_weights
# handles the sync with proper interval tracking.
#
# The design has two modes that have to be threaded carefully:
#
# - Async prefetch ON: BG generation thread can't safely call
# sync_weights mid-rollout (it races with the trainer's optimizer
# step and can corrupt weights). We no-op the stock sync hook and
# drive sync ourselves from ``_maybe_sync_vllm_weights`` after the
# optimizer step on the main thread.
#
# - Async prefetch OFF (synchronous mode): TRL's stock
# ``_generate_single_turn`` calls ``sync_weights`` once per step
# boundary. There's no BG thread to race with, and
# ``_maybe_sync_vllm_weights`` short-circuits with
# ``if not async_prefetch: return``, so we MUST wire the stock
# hook directly to our LoRA sync helper — otherwise nothing ever
# pushes weights to vLLM and the trainer becomes a no-op (vLLM
# keeps serving the base model, every rollout in every group
# produces identical outputs, advantages are zero, optimizer
# step gets skipped, repeat).
if not getattr(self, "_patched_sync_weights", False):
if self.use_vllm and hasattr(self, "vllm_generation"):
if getattr(self.args, "vllm_lora_sync", False):
# No-op: LoRA sync is driven by _maybe_sync_vllm_weights
self.vllm_generation.sync_weights = lambda: None
if getattr(self.args, "async_prefetch", False):
# Async: drive sync from main thread via
# _maybe_sync_vllm_weights instead.
self.vllm_generation.sync_weights = lambda: None
else:
# Sync mode: TRL's _generate_single_turn already
# calls sync_weights once per step boundary. Wire
# it directly to our LoRA filesystem sync helper.
sync_helper = self._sync_lora_adapter
def _lora_filesystem_sync():
sync_helper()
self.vllm_generation.sync_weights = _lora_filesystem_sync
self._patched_sync_weights = True
else:
from accelerate.utils import is_peft_model

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
```
## Usage
@@ -44,6 +44,7 @@ plugins:
- gemma3_text
- gemma3n
- gemma3n_text
- gemma4
- glm
- glm4
- glm4_moe

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
+'`pip uninstall -y cut-cross-entropy && pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
)

View File

@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

View File

@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):

View File

@@ -2,17 +2,35 @@
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
from .lora_layout import (
peft_down_proj_lora_to_scattermoe,
peft_lora_B_to_scattermoe,
peft_lora_to_scattermoe,
validate_scattermoe_lora_shapes,
)
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
"peft_down_proj_lora_to_scattermoe",
"peft_lora_B_to_scattermoe",
"peft_lora_to_scattermoe",
"validate_scattermoe_lora_shapes",
]
try:
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
except ModuleNotFoundError as exc:
if exc.name != "triton":
raise
else:
__all__ += [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]
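The guard in the `try`/`except` above is a generic pattern: swallow only the `ModuleNotFoundError` whose `name` attribute matches the optional dependency, and re-raise everything else. A minimal sketch; `try_optional_import` is an illustrative helper, not part of the package:

```python
def try_optional_import(module_name, optional="triton"):
    # Returns True when the module imports, False when the *optional*
    # dependency itself is the one missing. Any other missing module
    # indicates a real problem and is re-raised.
    try:
        __import__(module_name)
        return True
    except ModuleNotFoundError as exc:
        if exc.name != optional:
            raise
        return False
```

Here the `exc.name != "triton"` comparison is what keeps an unrelated import failure inside the optional modules from being silently masked.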

View File

@@ -35,81 +35,19 @@ import torch
from torch import nn
from torch.nn import functional as F
from .lora_layout import (
peft_down_proj_lora_to_scattermoe,
peft_lora_B_to_scattermoe,
peft_lora_to_scattermoe,
)
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
__all__ = [
"peft_down_proj_lora_to_scattermoe",
"peft_lora_B_to_scattermoe",
"peft_lora_to_scattermoe",
]
# =============================================================================
# ParamWrapper unwrapping
@@ -199,7 +137,7 @@ def _unwrap_experts_lora(experts_module):
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
# Extract gate_up_proj LoRA
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
@@ -208,7 +146,7 @@ def _unwrap_experts_lora(experts_module):
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
# Extract down_proj LoRA
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:


@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Pure tensor layout helpers for ScatterMoE LoRA weights."""
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout.
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``out_features=dim1, in_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``), giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``.
peft gives:
lora_A ``[r*E, dim2]``, lora_B ``[dim1, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
peft's A already matches ScatterMoE's A shape. Only B needs conversion from
peft's rank-major layout to ScatterMoE's expert-major layout.
"""
smoe_A = peft_A
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
def validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B):
"""Validate LoRA tensor layout before dispatching ScatterMoE kernels."""
E, K, N = expert_weights.shape
if lora_A.dim() != 2 or lora_B.dim() != 2:
raise ValueError(
"ScatterMoE LoRA expects 2D lora_A and lora_B tensors, got "
f"lora_A={tuple(lora_A.shape)} and lora_B={tuple(lora_B.shape)}."
)
if lora_A.size(0) % E != 0:
raise ValueError(
"ScatterMoE LoRA expects lora_A rows to be divisible by the number "
f"of experts ({E}), got lora_A={tuple(lora_A.shape)}."
)
rank = lora_A.size(0) // E
expected_A = (E * rank, K)
expected_B = (N, E * rank)
if tuple(lora_A.shape) != expected_A or tuple(lora_B.shape) != expected_B:
raise ValueError(
"Invalid ScatterMoE LoRA layout for expert_weights "
f"{tuple(expert_weights.shape)}. Expected lora_A={expected_A} and "
f"lora_B={expected_B}, got lora_A={tuple(lora_A.shape)} and "
f"lora_B={tuple(lora_B.shape)}. For PEFT target_parameters, keep "
"lora_A as [E*r, K] and only convert lora_B from rank-major to "
"expert-major layout."
)
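The rank-major to expert-major conversion above can be checked against a per-expert slice reference. A minimal sketch (torch only; the reshape/permute chain mirrors `peft_lora_B_to_scattermoe`, the slice pattern mirrors scattermoe's `[:, e*r:(e+1)*r]`):

```python
import torch

E, r, N = 4, 2, 6               # experts, rank, out_features
peft_B = torch.randn(N, E * r)  # rank-major: expert e, rank j lives at col j*E + e

# same reshape/permute chain as peft_lora_B_to_scattermoe above
smoe_B = peft_B.reshape(N, r, E).permute(0, 2, 1).contiguous().reshape(N, E * r)

# reference: expert e's columns in the rank-major layout are peft_B[:, e::E]
for e in range(E):
    assert torch.equal(smoe_B[:, e * r:(e + 1) * r], peft_B[:, e::E])
```

The strided slice `peft_B[:, e::E]` picks columns `e, e+E, ..., e+(r-1)*E`, which is exactly expert `e`'s rank block after the permute.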


@@ -34,6 +34,7 @@ from .kernels.lora_ops import (
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
from .lora_layout import validate_scattermoe_lora_shapes
class ScatterMoELoRA(torch.autograd.Function):
@@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple:
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
@@ -451,6 +447,7 @@ def parallel_linear_lora(
Otherwise falls back to standard scatter2scatter.
"""
if lora_A is not None and lora_B is not None:
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
return ScatterMoELoRA.apply(
inputs,
expert_weights,


@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from transformers.models.gemma4 import modeling_gemma4
if cfg.liger_rms_norm:
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
class _LigerGemma4RMSNorm(LigerRMSNorm):
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
def __new__(cls, dim, eps=1e-6, with_scale=True):
if not with_scale:
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
return super().__new__(cls)
def __init__(self, dim, eps=1e-6, with_scale=True):
if not with_scale:
return
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
super().__init__(
dim, eps, offset=0.0, casting_mode="llama", in_place=False
)
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
if cfg.liger_glu_activation:
class _LigerGemma4MLP(LigerGEGLUMLP):
def __init__(self, config, layer_idx=None):
super().__init__(config)
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
if cfg.liger_rope:
LOG.warning(
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
)
if cfg.liger_layer_norm:
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
LOG.warning(
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
)
LOG.info(
f"Applied Liger kernels for gemma4: "
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward
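The `__new__` fallback in `_LigerGemma4RMSNorm` is a general pattern: construct an instance of a different class when a constructor flag disables the optimized path. A standalone sketch with illustrative class names (not Liger's or Gemma's):

```python
class PlainNorm:
    """Stand-in for the stock norm class (illustrative name)."""
    def __init__(self, dim, eps=1e-6, with_scale=True):
        self.dim, self.eps, self.with_scale = dim, eps, with_scale

class FusedNorm(PlainNorm):
    """Stand-in for the fused subclass: bail out via __new__."""
    def __new__(cls, dim, eps=1e-6, with_scale=True):
        if not with_scale:
            # Returning a non-instance of cls means Python skips cls.__init__,
            # so the fallback object keeps the state set here.
            return PlainNorm(dim, eps, with_scale=False)
        return super().__new__(cls)

assert type(FusedNorm(8, with_scale=False)) is PlainNorm
assert type(FusedNorm(8)) is FusedNorm
```

This is why the real patch also guards `__init__` with `if not with_scale: return`: it only runs for the fused case.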


@@ -110,11 +110,36 @@ class NemoGymDataProducer(GRPODataProducer):
item["agent_ref"] = full_item["agent_ref"]
dataset_items.append(item)
# Expand by num_generations (agent produces one rollout per call)
expanded_items = []
for item in dataset_items:
for _ in range(self._num_generations):
expanded_items.append(item)
# NOTE: do NOT re-expand by num_generations here.
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
# yields ``num_generations`` consecutive copies of each unique
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
# num_generations)`` items — one entry per rollout. Expanding
# again here would fire ``num_generations^2`` rollouts per
# prompt per rank and make every step dogpile on a handful of
# tasks.
expanded_items = dataset_items
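The NOTE above can be checked with a toy count. Assuming only the documented behavior of TRL's `RepeatSampler(mini_repeat_count=g)` (g consecutive copies per unique prompt), re-expanding would fire g^2 rollouts per prompt:

```python
prompts = ["p0", "p1"]
g = 3  # num_generations

# what the sampler already yields: g consecutive copies per unique prompt
sampled = [p for p in prompts for _ in range(g)]
assert sampled == ["p0", "p0", "p0", "p1", "p1", "p1"]

# the removed re-expansion would copy each sampled entry g more times,
# i.e. g**2 rollouts per unique prompt per rank
re_expanded = [p for p in sampled for _ in range(g)]
assert len(re_expanded) == len(prompts) * g * g
```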
# Diagnostic: log what this rank is about to fire.
try:
import collections
iid_counts: collections.Counter[str | None] = collections.Counter()
for it in dataset_items:
iid_counts[
(it.get("responses_create_params", {}).get("metadata") or {}).get(
"instance_id"
)
] += 1
LOG.info(
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
trainer.accelerator.process_index,
len(dataset_items),
len(iid_counts),
list(iid_counts.most_common(5)),
)
except Exception:
pass
# Call NeMo Gym agents
loop = asyncio.new_event_loop()
@@ -140,6 +165,7 @@ class NemoGymDataProducer(GRPODataProducer):
logprobs_list = []
rewards_list = []
num_turns_list: list[int] = []
for resp in responses:
parsed = _parse_agent_response(resp, eos_token_id)
prompt_ids_list.append(parsed["prompt_ids"])
@@ -147,6 +173,7 @@ class NemoGymDataProducer(GRPODataProducer):
env_mask_list.append(parsed["env_mask"])
logprobs_list.append(parsed["logprobs"])
rewards_list.append(parsed["reward"])
num_turns_list.append(parsed.get("num_turns", 0))
# Pad to tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -179,22 +206,48 @@ class NemoGymDataProducer(GRPODataProducer):
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
# Inject rewards into inputs so _compute_deferred_scores can use them
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
# Our passthrough reward_fn reads "env_reward" from kwargs.
# Inject per-rollout reward + num_turns into each input. Since
# ``RepeatSampler`` already yields ``num_generations`` copies of
# each prompt, ``inputs`` has ONE entry per rollout (matching
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
# GRPO advantage normalization is the trainer's job downstream.
assert len(inputs) == len(rewards_list), (
f"rewards/inputs length mismatch: "
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
)
for i, inp in enumerate(inputs):
# Each input gets rewards for its num_generations rollouts
start = i * self._num_generations
end = start + self._num_generations
inp["env_reward"] = rewards_list[start:end]
inp["env_reward"] = rewards_list[i]
inp["num_turns"] = num_turns_list[i]
# Expand inputs to match expanded rollouts (num_generations copies)
expanded_inputs = []
for inp in inputs:
for g in range(self._num_generations):
expanded_inp = dict(inp)
expanded_inp["env_reward"] = inp["env_reward"][g]
expanded_inputs.append(expanded_inp)
# One expanded_input per rollout (already correct count because
# inputs has num_generations copies baked in by the sampler).
expanded_inputs = [dict(inp) for inp in inputs]
# Log rollout-level stats to wandb from rank 0. These are the
# true agent-side metrics (not the tokenized TRL view) — so
# num_turns reflects how many /run iterations each rollout
# actually took before finishing or hitting max_turns.
if is_main and num_turns_list:
try:
import wandb
if wandb.run is not None:
import statistics as _stats
nonzero = sum(1 for r in rewards_list if r > 0)
log_payload = {
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
"rollout/num_turns/min": float(min(num_turns_list)),
"rollout/num_turns/max": float(max(num_turns_list)),
"rollout/reward/mean": float(_stats.mean(rewards_list)),
"rollout/reward/nonzero_frac": (
nonzero / len(rewards_list) if rewards_list else 0.0
),
"rollout/n_samples": float(len(rewards_list)),
}
wandb.log(log_payload, commit=False)
except Exception as exc: # never let metric logging break training
LOG.warning("rollout wandb log failed: %s", exc)
# Decode completions for reward functions
completions = trainer.processing_class.batch_decode(


@@ -19,6 +19,7 @@ Supports two modes:
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Union
from axolotl.integrations.base import BasePlugin
@@ -30,6 +31,107 @@ if TYPE_CHECKING:
LOG = get_logger(__name__)
# ---- vLLM weight-sync transport probe ------------------------------------
@dataclass
class VLLMWeightSyncCapabilities:
"""What weight-sync routes a vLLM server actually exposes.
Discovered once at ``pre_model_load`` time by fetching the server's
``/openapi.json``. Drives the transport-selection table below.
"""
nccl: bool = False # /init_communicator/ + /update_named_param/
lora_filesystem: bool = False # /v1/load_lora_adapter (vLLM native)
lora_axolotl: bool = False # /set_lora_adapter/ (axolotl serve_lora extension)
http_full: bool = False # /http_update_weights/ (axolotl serve_lora extension)
probed: bool = False
probe_error: str | None = None
routes: list[str] = field(default_factory=list)
@property
def any_full_param_sync(self) -> bool:
"""True if at least one transport can push full-model weights."""
return self.nccl or self.http_full
@property
def any_lora_sync(self) -> bool:
"""True if at least one transport can push LoRA adapters."""
return self.lora_filesystem or self.lora_axolotl or self.nccl
def probe_vllm_weight_sync(
base_url: str, timeout: float = 5.0
) -> VLLMWeightSyncCapabilities:
"""Detect which weight-sync routes the configured vLLM server exposes.
Uses the server's FastAPI ``/openapi.json`` — every weight-sync transport
we care about is mounted as a POST route there. Falls back to all-False
on any error so the caller can still decide what to do (typically: raise
a clear error rather than silently no-op).
"""
import requests
caps = VLLMWeightSyncCapabilities()
try:
r = requests.get(f"{base_url.rstrip('/')}/openapi.json", timeout=timeout)
r.raise_for_status()
spec = r.json()
routes = sorted((spec.get("paths") or {}).keys())
caps.routes = routes
caps.nccl = "/init_communicator/" in routes and "/update_named_param/" in routes
caps.lora_filesystem = "/v1/load_lora_adapter" in routes
caps.lora_axolotl = "/set_lora_adapter/" in routes
caps.http_full = "/http_update_weights/" in routes
caps.probed = True
except Exception as exc:
caps.probe_error = f"{type(exc).__name__}: {exc}"
LOG.warning(
"NeMo Gym: failed to probe vLLM /openapi.json at %s%s. "
"Will fall back to LoRA-only behavior.",
base_url,
caps.probe_error,
)
return caps
def select_weight_sync_transport(
caps: VLLMWeightSyncCapabilities,
*,
has_lora: bool,
vllm_lora_sync_pref: bool,
) -> str:
"""Pick the right transport for a (server caps, model type) combo.
Returns one of: ``"lora_filesystem"``, ``"nccl"``, ``"http_full"``, or
``"none"``. The caller decides what to do with ``"none"`` (typically:
raise an error explaining the misconfiguration).
Selection table:
LoRA model + lora endpoint + lora-sync pref → lora_filesystem
LoRA model + lora endpoint → lora_filesystem
LoRA model + nccl endpoint → nccl (broadcast merged adapter)
Full model + nccl endpoint → nccl
Full model + http endpoint → http_full
anything else → none
"""
if has_lora:
if (caps.lora_filesystem or caps.lora_axolotl) and vllm_lora_sync_pref:
return "lora_filesystem"
if caps.lora_filesystem or caps.lora_axolotl:
return "lora_filesystem"
if caps.nccl:
return "nccl"
return "none"
# Full-parameter model
if caps.nccl:
return "nccl"
if caps.http_full:
return "http_full"
return "none"
class NemoGymPlugin(BasePlugin):
"""Plugin for NVIDIA NeMo Gym integration with Axolotl.
@@ -50,37 +152,69 @@ class NemoGymPlugin(BasePlugin):
self._reward_fn = None
self._dataset_lookup = None
self._agent_servers = {}
self._vllm_caps: VLLMWeightSyncCapabilities | None = None
def get_input_args(self):
return "axolotl.integrations.nemo_gym.NemoGymArgs"
def pre_model_load(self, cfg):
"""Apply monkeypatches before trainer creation."""
"""Probe vLLM weight-sync routes and conditionally bypass NCCL init.
Replaces the previous unconditional ``init_communicator`` monkey-patch
with a probe of the configured vLLM server's ``/openapi.json``. We only
bypass NCCL init when the server we're talking to actually lacks the
``/init_communicator/`` route (i.e. stock ``vllm serve``); against
TRL/axolotl serve modules that DO expose NCCL routes, we leave the
standard TRL flow alone so full-finetune training can sync weights.
"""
if not cfg.nemo_gym_enabled:
return
# Always skip NCCL communicator init in NeMo Gym mode.
# NeMo Gym uses its own vLLM server (standard OpenAI API), not the TRL
# colocate/NCCL path. The NCCL init fails with vLLM V1 and standard servers.
trl_cfg = getattr(cfg, "trl", None)
if trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server":
if not (trl_cfg and getattr(trl_cfg, "vllm_mode", "server") == "server"):
return
host = getattr(trl_cfg, "vllm_server_host", None) or "127.0.0.1"
port = getattr(trl_cfg, "vllm_server_port", None) or 8000
base_url = f"http://{host}:{port}"
self._vllm_caps = probe_vllm_weight_sync(base_url)
if self._vllm_caps.probed:
LOG.info(
"NeMo Gym: vLLM weight-sync probe @ %s — nccl=%s lora_native=%s "
"lora_axolotl=%s http_full=%s",
base_url,
self._vllm_caps.nccl,
self._vllm_caps.lora_filesystem,
self._vllm_caps.lora_axolotl,
self._vllm_caps.http_full,
)
# Only bypass NCCL init when the server doesn't speak it. If NCCL is
# available we leave VLLMClient.init_communicator alone so the
# standard TRL sync flow can run for full-parameter training.
if not self._vllm_caps.nccl:
self._patch_skip_nccl_init()
def _patch_skip_nccl_init(self):
"""Monkeypatch VLLMClient.init_communicator to no-op.
NeMo Gym uses its own vLLM server (standard OpenAI API or custom LoRA
serve script). The NCCL communicator is not needed and fails with both
vLLM V1 engine and standard OpenAI server mode.
Only called when the configured vLLM server doesn't expose
``/init_communicator/`` (e.g. stock ``vllm serve``). In that case
TRL's standard ``init_communicator`` would 404 inside trainer
construction; we no-op it so the LoRA filesystem path can install
its own sync in ``post_trainer_create``.
"""
try:
from trl.generation.vllm_client import VLLMClient
VLLMClient._original_init_communicator = VLLMClient.init_communicator
VLLMClient.init_communicator = lambda self, **kwargs: LOG.info(
"Skipping NCCL init_communicator (LoRA sync mode)"
"Skipping NCCL init_communicator (server has no /init_communicator/)"
)
LOG.info(
"Patched VLLMClient.init_communicator to no-op (server has no NCCL routes)"
)
LOG.info("Patched VLLMClient.init_communicator to no-op for LoRA sync")
except Exception as exc:
LOG.warning(f"Failed to patch VLLMClient: {exc}")
@@ -234,30 +368,80 @@ class NemoGymPlugin(BasePlugin):
verify_timeout = cfg.nemo_gym_verify_timeout or 30
multi_turn = cfg.nemo_gym_multi_turn or False
# Handle weight sync. NeMo Gym skips NCCL init, so we need to either:
# - Install LoRA sync (when vllm_lora_sync=True)
# - Or no-op sync_weights (when using standard vLLM server)
# Pick a weight-sync transport based on what the configured vLLM
# server actually exposes (see ``pre_model_load`` probe) and what
# kind of model we're training. The selection table is documented
# in ``select_weight_sync_transport``.
trl_cfg = getattr(cfg, "trl", None)
if hasattr(trainer, "vllm_generation") and trainer.vllm_generation:
vllm_gen = trainer.vllm_generation
if trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False):
adapter = getattr(cfg, "adapter", None)
has_lora = adapter in ("lora", "qlora")
vllm_lora_sync_pref = bool(
trl_cfg and getattr(trl_cfg, "vllm_lora_sync", False)
)
caps = self._vllm_caps or VLLMWeightSyncCapabilities()
transport = select_weight_sync_transport(
caps,
has_lora=has_lora,
vllm_lora_sync_pref=vllm_lora_sync_pref,
)
if transport == "lora_filesystem":
self._setup_lora_sync(trainer)
# Verify the vLLM server supports runtime LoRA loading
self._check_lora_endpoint(vllm_gen)
else:
# No NCCL, no LoRA sync — skip all weight sync paths
vllm_gen.sync_weights = lambda: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
LOG.info("NeMo Gym weight sync: LoRA filesystem")
elif transport == "nccl":
# Standard TRL NCCL path. We leave ``VLLMClient.init_communicator``
# alone (pre_model_load only patched it when the probe found no
# NCCL route) so the trainer's normal weight-sync flow runs.
LOG.info(
"NeMo Gym weight sync: NCCL (server exposes /init_communicator/)"
)
type(vllm_gen).sync_weights = lambda self: LOG.debug(
"Weight sync skipped (NeMo Gym mode)"
elif transport == "http_full":
# Full-parameter HTTP sync — implementation lands in step 3.
# For now, fail loudly so users know the path is detected but
# not yet wired up, instead of silently no-oping like before.
raise NotImplementedError(
"NeMo Gym + full fine-tune + HTTP weight sync is detected "
"but the client-side sync helper is not yet implemented "
"(planned). Use `adapter: lora|qlora` for now, or use a "
"vLLM serve module that exposes /init_communicator/ for "
"NCCL sync."
)
# Also patch the async trainer's internal sync method
if hasattr(trainer, "_maybe_sync_vllm_weights"):
trainer._maybe_sync_vllm_weights = lambda: LOG.debug(
"Async weight sync skipped (NeMo Gym mode)"
else: # transport == "none"
# No viable sync path. Build a precise error so the user knows
# exactly what's missing and how to fix it.
if not caps.probed:
msg = (
"could not probe the vLLM server's "
f"/openapi.json: {caps.probe_error}. "
"Verify that vLLM is reachable at "
f"{getattr(trl_cfg, 'vllm_server_host', '?')}:"
f"{getattr(trl_cfg, 'vllm_server_port', '?')}."
)
LOG.info("Disabled weight sync (NeMo Gym mode, no LoRA sync)")
elif has_lora:
msg = (
"the vLLM server has neither NCCL routes "
"(/init_communicator/) nor a LoRA-loading route "
"(/v1/load_lora_adapter or /set_lora_adapter/). "
"Restart vLLM with `--enable-lora --max-lora-rank N "
"VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the stock "
"server, or use `axolotl vllm-serve` for the "
"NCCL-capable serve module."
)
else:
msg = (
"the vLLM server exposes no full-parameter sync route "
"(/init_communicator/ for NCCL or /http_update_weights/ "
"for HTTP). Use `axolotl vllm-serve` (which has both) "
"or set `adapter: lora|qlora`."
)
raise ValueError(
f"NeMo Gym: no usable weight-sync transport — {msg} Without "
"weight sync the trainer's gradient updates never reach the "
"rollout policy (functionally a no-op trainer)."
)
if multi_turn:
self._wire_multi_turn(cfg, trainer, model_name, verify_timeout)


@@ -130,21 +130,41 @@ def start_servers(
)
def get_server_configs(head_port: int = 11000) -> dict:
def get_server_configs(head_port: int = 11000, timeout: float = 30.0) -> dict:
"""Fetch the global config from the NeMo Gym head server.
Retries up to 3 times with exponential backoff. The default per-attempt
timeout is 30s (raised from the original 5s) because head servers can
be slow to respond when they're concurrently serving rollouts from a
prior training run. A 5s timeout was empirically too tight to survive
a kill-and-relaunch cycle.
Returns:
Dict mapping server_name -> server config.
"""
response = requests.get(
f"http://127.0.0.1:{head_port}/global_config_dict_yaml", timeout=5
url = f"http://127.0.0.1:{head_port}/global_config_dict_yaml"
last_exc: Exception | None = None
for attempt in (1, 2, 3):
try:
response = requests.get(url, timeout=timeout)
response.raise_for_status()
result = yaml.safe_load(response.text)
# NeMo Gym head server double-encodes: YAML string inside a YAML string
if isinstance(result, str):
result = yaml.safe_load(result)
return result
except (requests.exceptions.RequestException, OSError) as exc:
last_exc = exc
LOG.warning(
"NeMo Gym head probe attempt %d/3 failed: %s. Retrying...",
attempt,
type(exc).__name__,
)
if attempt < 3:
time.sleep(2.0 * attempt)
raise RuntimeError(
f"NeMo Gym head server at {url} did not respond after 3 attempts: {last_exc}"
)
response.raise_for_status()
result = yaml.safe_load(response.text)
# NeMo Gym head server double-encodes: YAML string inside a YAML string
if isinstance(result, str):
result = yaml.safe_load(result)
return result
def get_agent_servers(


@@ -0,0 +1,593 @@
"""
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
Fuses three operations into one kernel launch:
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
3. (optional) RMSNorm without scale (for v_norm)
This eliminates two intermediate tensor materializations per Q/K path and the
allocation churn from rotate_half / apply_rotary_pos_emb.
Shapes:
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
    cos: (B*S, n_rot) — rotary table, broadcast across heads inside the
         kernel via ``row_idx // n_heads`` indexing (``n_rot <= head_dim``)
    sin: (B*S, n_rot) — same as cos
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_rope_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
n_rot,
n_heads,
eps,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused forward:
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
y[..., :n_rot] = rope(x_norm[..., :n_rot])
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
rotate_half swaps first/second halves and negates the first, restricted
to the rotary span [0, n_rot):
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
For the partial-rotary pass-through region we load cos with default 1.0
and sin with default 0.0 outside [0, n_rot), so the same formula
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
"""
row_idx = tl.program_id(0).to(tl.int64)
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
cs_row_idx = row_idx // n_heads
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
rot_mask_col = col_offsets < n_rot
half_rot = n_rot // 2
# Load input row
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
# Normalize
X_norm = X_fp32 * rstd
# Apply weight if present (with_scale=True)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
X_norm = X_norm * W_row
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
# cos=1, sin=0 so the formula leaves X_norm untouched.
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
mask=rot_mask_col,
other=1.0,
).to(tl.float32)
sin_row = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
mask=rot_mask_col,
other=0.0,
).to(tl.float32)
# rotate_half within [0, n_rot):
# for col < half_rot: take -X_norm[col + half_rot]
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
rot_offsets = tl.where(
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
)
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
X_rot = tl.load(
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
).to(tl.float32)
# Re-normalize the rotated values
X_rot_norm = X_rot * rstd
if HAS_WEIGHT:
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
X_rot_norm = X_rot_norm * W_rot
# Negate the first half (rotate_half negates x2, which becomes the first half)
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
X_rot_norm = X_rot_norm * sign
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
Y_row = X_norm * cos_row + X_rot_norm * sin_row
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_rope_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
n_rot,
n_heads,
rows_per_program,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
(`n_rot <= n_cols`).
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
output is just the normalized row, so dN[col] = dY[col] (achieved by
loading cos with default 1.0 and forcing the rotate-half contribution
to zero outside the rotary span).
cos/sin indexed by row_idx // n_heads for per-head broadcast.
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
rot_mask_col = col_offsets < n_rot
half_rot = n_rot // 2
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
for row_idx in range(row_start, row_end):
cs_row_idx = row_idx // n_heads
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
mask=rot_mask_col,
other=1.0,
).to(tl.float32)
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
#
# For col >= n_rot the formula must collapse to dN = dY (since the
# forward is just a pass-through). cos defaults to 1.0 above; the
# rotate-half contribution is masked to zero below.
rot_offsets = tl.where(
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
)
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
dY_rot = tl.load(
dY_ptr + row_idx * dY_row_stride + rot_offsets,
mask=rot_load_mask,
other=0,
).to(tl.float32)
sin_rot = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
mask=rot_load_mask,
other=0,
).to(tl.float32)
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
rotate_term = dY_rot * sin_rot * adj_sign
# Zero out rotate-half contribution outside the rotary span.
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
dN = dY_row * cos_row + rotate_term
# Pre-weight normalized: n = rstd * x
n = X_row * rstd
if HAS_WEIGHT:
dW_acc += dN * n
dm = dN * W_row
else:
dm = dN
# RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
dot_dm_x = tl.sum(dm * X_row, axis=0)
dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
if HAS_WEIGHT:
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
"""
Args:
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
W: (head_dim,) or None — RMSNorm weight
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
eps: float
n_heads: int — number of attention heads (for cos/sin indexing)
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
partial rotary). Must be even and ``<= head_dim``.
Returns:
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
"""
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
has_weight = W is not None
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_rope_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
W if has_weight else X, # dummy pointer when no weight
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
n_rot,
n_heads,
eps,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y, X, RSTD, BLOCK_SIZE, num_warps
def rms_norm_rope_backward(
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
):
n_rows, n_cols = dY.shape
has_weight = W is not None
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
rows_per_program = math.ceil(n_rows / sm_count)
dX = torch.empty_like(X)
if has_weight:
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
else:
_dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
_rms_norm_rope_backward_kernel[(sm_count,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W if has_weight else X, # dummy
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
n_rot,
n_heads,
rows_per_program,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
return dX, dW
class FusedRMSNormRoPEFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
"""
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, n_rot) — broadcast across heads
sin: (B*S, n_rot) — broadcast across heads
n_heads: int
n_rot: int — rotary dim (<= head_dim)
"""
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
X,
W,
cos,
sin,
eps,
n_heads,
n_rot,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_heads = n_heads
ctx.n_rot = n_rot
ctx.has_weight = W is not None
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, cos, sin, RSTD = ctx.saved_tensors
dX, dW = rms_norm_rope_backward(
dY,
X,
W,
cos,
sin,
RSTD,
ctx.n_heads,
ctx.n_rot,
ctx.BLOCK_SIZE,
ctx.num_warps,
)
return dX, dW, None, None, None, None, None
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
    """
    Apply fused RMSNorm + (partial) RoPE.

    Args:
        x: (batch, seq_len, num_heads, head_dim) — after projection + view
        weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
        cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
            must be even and ``<= head_dim``. When ``n_rot < head_dim``
            the trailing ``head_dim - n_rot`` columns are RMSNorm-only
            (partial-rotary pass-through), matching stock Gemma 4 with
            ``partial_rotary_factor < 1.0``.
        sin: (batch, seq_len, n_rot) — same shape as ``cos``
        eps: float — RMSNorm epsilon

    Returns:
        y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
    """
    shape = x.shape  # (B, S, H, D)
    B, S, H, D = shape
    n_rot = cos.shape[-1]
    if sin.shape[-1] != n_rot:
        raise ValueError(
            f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
            f"sin={sin.shape[-1]}"
        )
    if n_rot > D:
        raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
    if n_rot % 2 != 0:
        raise ValueError(f"rotary dim must be even, got {n_rot}")

    # Flatten to 2D: (B*S*H, D)
    x_flat = x.reshape(-1, D).contiguous()

    # cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
    # all sequences share the same rotary positions). The kernel needs a
    # dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
    # onto a single (b, s) pair, so expand-then-contiguous to materialize
    # the per-batch broadcast (skipped when cos already has batch dim B).
    if cos.shape[0] != B:
        if cos.shape[0] != 1:
            raise ValueError(
                f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
                f"to x batch dim ({B})"
            )
        cos = cos.expand(B, S, n_rot)
        sin = sin.expand(B, S, n_rot)
    cos_flat = cos.reshape(B * S, n_rot).contiguous()
    sin_flat = sin.reshape(B * S, n_rot).contiguous()

    y_flat = FusedRMSNormRoPEFunction.apply(
        x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
    )
    return y_flat.view(shape)
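The per-row math the wrapper documents can be written out as a plain-Python reference for a single (b, s, h) row. This sketch assumes the HF-style `rotate_half` convention (first half of the rotary dims pairs with the negated second half) and illustrates the partial-rotary pass-through, where columns at or beyond `n_rot` receive the RMSNorm but no rotation:

```python
import math

def rms_norm_rope_row(x, w, cos, sin, eps=1e-6):
    # One (b, s, h) row: RMSNorm over head_dim, then rotate_half RoPE on
    # the first n_rot columns; columns >= n_rot pass through norm-only.
    d, n_rot, half = len(x), len(cos), len(cos) // 2
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    n = [v * rstd * wi for v, wi in zip(x, w)]
    y = list(n)
    for i in range(half):
        # rotate_half: y1 = n1*cos - n2*sin, y2 = n2*cos + n1*sin
        y[i] = n[i] * cos[i] - n[i + half] * sin[i]
        y[i + half] = n[i + half] * cos[i + half] + n[i] * sin[i + half]
    return y

theta = 0.3
c, s = math.cos(theta), math.sin(theta)
x = [1.0, -2.0, 0.5, 3.0, 4.0, -1.0]  # head_dim = 6
w = [1.0] * 6
y = rms_norm_rope_row(x, w, [c] * 4, [s] * 4)  # n_rot = 4 < head_dim
# trailing head_dim - n_rot columns are RMSNorm-only (cos=1, sin=0 defaults)
rstd = 1.0 / math.sqrt(sum(v * v for v in x) / 6 + 1e-6)
assert all(abs(y[i] - x[i] * rstd) < 1e-9 for i in (4, 5))
```

The fused kernel performs the same computation per row, with `row_idx // n_heads` selecting the (b, s) entry of the dense cos/sin buffers.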
@triton.jit
def _rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""RMSNorm without scale weight: y = x / rms(x)"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
Y_row = X_fp32 * rstd
tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
@triton.jit
def _rms_norm_noscale_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
RSTD_ptr,
RSTD_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""Backward for y = x * rstd (no weight)."""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
dot_dy_x = tl.sum(dY_row * X_row, axis=0)
dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
)
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
"""RMSNorm without learnable scale — used for Gemma4's v_norm."""
@staticmethod
@ensure_contiguous
def forward(ctx, X, eps):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, RSTD)
ctx.n_cols = n_cols
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, RSTD = ctx.saved_tensors
n_rows = X.shape[0]
dX = torch.empty_like(X)
_rms_norm_noscale_backward_kernel[(n_rows,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
RSTD,
RSTD.stride(0),
ctx.n_cols,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
return dX, None
def fused_rms_norm_noscale(x, eps=1e-6):
    """
    RMSNorm without scale for v_norm.

    Args:
        x: (batch, seq_len, num_heads, head_dim)

    Returns:
        y: same shape, normalized
    """
    shape = x.shape
    x_flat = x.reshape(-1, shape[-1])
    y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
    return y_flat.view(shape)
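The no-scale variant simply normalizes each flattened head_dim row to unit RMS. A one-function reference in plain Python makes the invariant easy to check:

```python
import math

def rms_norm_noscale_ref(x, eps=1e-6):
    # reference for one flattened row: y = x / sqrt(mean(x^2) + eps)
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * rstd for v in x]

y = rms_norm_noscale_ref([1.0, 2.0, 3.0, 4.0])
rms = math.sqrt(sum(v * v for v in y) / len(y))
assert abs(rms - 1.0) < 1e-5  # output rows have RMS ~= 1
```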


@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
return Q, K, V
class LoRA_QK(torch.autograd.Function):
"""Optimized LoRA QK implementation for models where v_proj is None.
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
reused as value states. Only Q and K projections are fused; the caller
returns K a second time as V so that autograd accumulates key+value gradients
into a single dK.
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
X_drop: torch.Tensor | None,
# Q params
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
q_lora_bias: torch.Tensor | None,
q_magnitude: torch.Tensor | None,
# K params
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
k_lora_bias: torch.Tensor | None,
k_magnitude: torch.Tensor | None,
# Flags
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
has_dropout = X_drop is not None
has_dora = q_magnitude is not None
if has_dora:
dtype = X.dtype
X_lora = X_drop if has_dropout else X
# Compute Q with DoRA
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
q_mag_scale = _compute_dora_scale(
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
)
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
if q_bias is not None:
Q = Q + q_bias
# Compute K with DoRA
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
k_mag_scale = _compute_dora_scale(
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
)
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
if k_bias is not None:
K = K + k_bias
Q_combined = Q_base + Q_lora
K_combined = K_base + K_lora
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
)
else:
# Standard LoRA (with optional dropout and bias)
Q = matmul_lora(
X,
q_weight,
q_bias,
q_quant,
q_A,
q_B,
q_scale,
X_drop=X_drop,
lora_bias=q_lora_bias,
)
K = matmul_lora(
X,
k_weight,
k_bias,
k_quant,
k_A,
k_B,
k_scale,
X_drop=X_drop,
lora_bias=k_lora_bias,
)
dtype = X.dtype
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_lora_bias,
k_lora_bias,
)
ctx.scales = (q_scale, k_scale)
ctx.quants = (q_quant, k_quant)
ctx.weights = (q_weight, k_weight)
ctx.inplace = inplace
ctx.has_dropout = has_dropout
ctx.has_dora = has_dora
return Q, K
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
):
q_weight, k_weight = ctx.weights
q_quant, k_quant = ctx.quants
q_scale, k_scale = ctx.scales
has_dropout = ctx.has_dropout
has_dora = ctx.has_dora
if has_dora:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
else:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
q_magnitude = k_magnitude = None
q_mag_scale = k_mag_scale = None
Q_combined = K_combined = None
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
X = X.view(-1, X.shape[-1])
X_lora = X_lora.view(-1, X_lora.shape[-1])
d_q_mag = d_k_mag = None
d_q_lora_bias = d_k_lora_bias = None
if has_dora:
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
K_combined = K_combined.view(-1, K_combined.shape[-1])
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
q_grad = q_grad * q_mag_scale.unsqueeze(0)
k_grad = k_grad * k_mag_scale.unsqueeze(0)
# LoRA bias gradients
if q_lora_bias is not None:
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
if k_lora_bias is not None:
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
X_lora_t = X_lora.t()
d_A_q = d_B_q = d_A_k = d_B_k = None
grad_B_q = grad_B_k = None
if A_q is not None and B_q is not None:
grad_B_q = q_grad @ B_q
d_A_q = torch.empty_like(A_q.t())
d_B_q = torch.empty_like(B_q.t())
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
if A_k is not None and B_k is not None:
grad_B_k = k_grad @ B_k
d_A_k = torch.empty_like(A_k.t())
d_B_k = torch.empty_like(B_k.t())
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
# Base path input gradient
out_buffer = X if ctx.inplace else None
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight_t
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight_t
# LoRA path input gradient
if has_dropout:
grad_X_drop = torch.zeros_like(X_lora)
if grad_B_q is not None:
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
else:
grad_X_drop = None
if grad_B_q is not None:
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
grad_X = grad_X.view(batch, seq_len, -1)
if grad_X_drop is not None:
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
# Return gradients for all forward inputs:
# X, X_drop,
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
# inplace
return (
grad_X,
grad_X_drop,
# Q
None,
None,
None,
d_A_q,
d_B_q,
None,
d_q_lora_bias,
d_q_mag,
# K
None,
None,
None,
d_A_k,
d_B_k,
None,
d_k_lora_bias,
d_k_mag,
# inplace
None,
)
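The `addmm_` calls in `backward` implement the standard LoRA gradients. Reading the shapes off the code (A: (r, in), B: (out, r), so the forward is presumably `Y = X @ W.T + scale * (X @ A.T) @ B.T`), the adapter gradients are `dA = scale * (dY @ B).T @ X` and `dB = scale * dY.T @ (X @ A.T)`. A pure-Python finite-difference check of the `dA` formula on small, arbitrarily-sized matrices:

```python
import random

def mm(A, B):
    # (m, k) @ (k, n) dense matmul on lists of lists
    return [
        [sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
        for i in range(len(A))
    ]

def T(M):
    return [list(r) for r in zip(*M)]

random.seed(0)
N, d_in, d_out, r, scale = 3, 4, 5, 2, 0.5
rnd = lambda m, n: [[random.uniform(-1, 1) for _ in range(n)] for _ in range(m)]
X, W, B, dY = rnd(N, d_in), rnd(d_out, d_in), rnd(d_out, r), rnd(N, d_out)
A = rnd(r, d_in)

def loss(A_):
    # L = <dY, X @ W.T + scale * (X @ A_.T) @ B.T>
    Y = mm(X, T(W))
    L = mm(mm(X, T(A_)), T(B))
    return sum(
        dY[i][j] * (Y[i][j] + scale * L[i][j])
        for i in range(N)
        for j in range(d_out)
    )

# analytic: dA = scale * (dY @ B).T @ X, shape (r, d_in)
dA = [[scale * v for v in row] for row in mm(T(mm(dY, B)), X)]
h = 1e-6
for i in range(r):
    for j in range(d_in):
        Ap = [row[:] for row in A]
        Am = [row[:] for row in A]
        Ap[i][j] += h
        Am[i][j] -= h
        numeric = (loss(Ap) - loss(Am)) / (2 * h)
        assert abs(numeric - dA[i][j]) < 1e-3
```

The same identity is what `d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)` computes in transposed form before the final `.t()`.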
def apply_lora_qk(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query and Key projections for models where v_proj is None.
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
Because K is returned twice, autograd accumulates gradients from both the key and
value paths into dK before calling LoRA_QK.backward.
Supports bias, dropout, and DoRA.
"""
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
# Apply dropout outside autograd.Function (shared mask for Q, K)
X_drop = _apply_dropout(Qdrop, X, self.training)
Q, K = LoRA_QK.apply(
X,
X_drop,
# Q
QW,
Qb,
QW_quant,
QA,
QB,
QS,
Qlb,
Qmag,
# K
KW,
Kb,
KW_quant,
KA,
KB,
KS,
Klb,
Kmag,
# Flags
inplace,
)
return Q, K, K
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection.


@@ -67,12 +67,165 @@ def find_all_linear_names(model):
return list(lora_module_names)
def _patch_peft_clippable_linear():
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
that clips activations). PEFT doesn't recognise it as a supported layer type,
so we redirect LoRA injection to the inner ``.linear`` child instead.
"""
try:
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4ClippableLinear as _cls,
)
except ImportError:
return
from peft.tuners.lora.model import LoraModel
if getattr(LoraModel, "_axolotl_clippable_patched", False):
return
_orig = LoraModel._create_and_replace
def _patched(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=None,
**kw,
):
if isinstance(target, _cls):
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
return _orig(
self,
peft_config,
adapter_name,
target.linear,
"linear",
target,
current_key=current_key,
**kw,
)
return _orig(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=current_key,
**kw,
)
LoraModel._create_and_replace = _patched
LoraModel._axolotl_clippable_patched = True
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
"""Check whether PEFT will auto-populate target_parameters for this model.
PEFT 0.19's ``convert_peft_config_for_transformers`` rewrites old MoE
``target_modules`` (e.g. ``w1``/``w2``/``w3`` on Mixtral) into
``target_parameters`` (``gate_up_proj``/``down_proj``) because
transformers v5 fused those expert linears into 3D ``nn.Parameter``
tensors. PEFT wraps the resulting 3D params with ``ParamWrapper``,
which rejects ``lora_dropout != 0``. This probe runs the conversion on
a copy of the config so we can detect the situation before
``get_peft_model`` blows up.
"""
if getattr(lora_config, "target_parameters", None):
return False
try:
from peft.utils.transformers_weight_conversion import (
convert_peft_config_for_transformers,
get_model_conversion_mapping,
)
except ImportError:
return False
import copy
probe_cfg = copy.deepcopy(lora_config)
try:
convert_peft_config_for_transformers(
probe_cfg,
model=model,
conversions=get_model_conversion_mapping(model),
)
except Exception: # pylint: disable=broad-except
return False
return bool(getattr(probe_cfg, "target_parameters", None))
def _patch_peft_param_wrapper_dropout():
"""Let PEFT's ``ParamWrapper`` silently accept ``lora_dropout != 0``.
``ParamWrapper`` wraps 3D expert ``nn.Parameter`` tensors and rejects
non-zero dropout because dropout can't be factored out of
``lora_B(lora_A(dropout(x)))`` when the inner op is an expert-indexed
matmul. For mixed configs (attention + MoE experts) this is too
aggressive — the non-expert ``Linear`` LoRA layers *can* apply dropout
and that's usually what the user intended. We pass a copy of the
``LoraConfig`` with ``lora_dropout=0`` only to ``ParamWrapper.__init__``
so it builds with ``nn.Identity`` for its internal dropout slot while
every other layer type still receives the real dropout value.
"""
from peft.tuners.lora.layer import ParamWrapper
if getattr(ParamWrapper, "_axolotl_dropout_patched", False):
return
_orig_init = ParamWrapper.__init__
def _patched_init(
self,
base_layer,
adapter_name,
parameter_name,
config,
*args,
**kwargs,
):
if getattr(config, "lora_dropout", 0):
import copy as _copy
patched_config = _copy.copy(config)
patched_config.lora_dropout = 0.0
return _orig_init(
self,
base_layer,
adapter_name,
parameter_name,
patched_config,
*args,
**kwargs,
)
return _orig_init(
self,
base_layer,
adapter_name,
parameter_name,
config,
*args,
**kwargs,
)
ParamWrapper.__init__ = _patched_init
ParamWrapper._axolotl_dropout_patched = True
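The pattern above (shallow-copy the config, zero the offending field, delegate to the saved original `__init__`, and guard with a class attribute so re-patching is a no-op) can be sketched in isolation. All names here are toy stand-ins, not PEFT's actual classes:

```python
import copy

class ParamWrapperToy:
    """Stand-in for a layer class that rejects nonzero dropout."""

    def __init__(self, config):
        if config["lora_dropout"]:
            raise ValueError("dropout unsupported")
        self.dropout = config["lora_dropout"]

def patch_dropout(cls):
    if getattr(cls, "_dropout_patched", False):
        return  # idempotent: second call is a no-op
    _orig_init = cls.__init__

    def _patched_init(self, config):
        if config.get("lora_dropout", 0):
            config = copy.copy(config)  # never mutate the caller's config
            config["lora_dropout"] = 0.0
        return _orig_init(self, config)

    cls.__init__ = _patched_init
    cls._dropout_patched = True

patch_dropout(ParamWrapperToy)
patch_dropout(ParamWrapperToy)  # safe to call twice
cfg = {"lora_dropout": 0.1}
w = ParamWrapperToy(cfg)
assert w.dropout == 0.0
assert cfg["lora_dropout"] == 0.1  # caller's config untouched
```

The copy is the important part: other layer types constructed from the same `LoraConfig` still see the real dropout value.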
def load_lora(
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
@@ -124,6 +277,7 @@ def load_lora(
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,
@@ -132,6 +286,20 @@ def load_lora(
if config_only:
return None, lora_config
if getattr(
lora_config, "lora_dropout", 0
) and _peft_will_auto_convert_target_params(model, lora_config):
LOG.warning(
"lora_dropout=%s requested but PEFT will wrap this model's fused "
"MoE expert parameters with ParamWrapper, which cannot apply "
"dropout (the 3D einsum can't factor dropout out of "
"lora_B(lora_A(dropout(x)))). Dropout will still be applied to "
"non-expert LoRA layers (e.g. attention), and expert LoRA layers "
"will use nn.Identity for the dropout slot.",
lora_config.lora_dropout,
)
_patch_peft_param_wrapper_dropout()
rank = int(os.environ.get("LOCAL_RANK", 0))
if (


@@ -547,6 +547,16 @@ class ModelLoader:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
if self.cfg.model_quantization_config == "FineGrainedFP8Config":
from transformers import FineGrainedFP8Config
fp8_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
fp8_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
**fp8_kwargs
)
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
LOG.warning(
@@ -624,7 +634,14 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:
if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers;
# global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
# Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True
elif self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"


@@ -156,6 +156,15 @@ class PatchManager:
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
# override needs to walk the raw model tree, which is broken by
# the post-load PEFT wrapping. The accompanying
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
# installation-time-independent, so both halves of the fix live
# cleanly in the same call even though one is instance-scoped
# and the other is module-scoped.
self._apply_gemma_hybrid_attention(model)
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -165,6 +174,83 @@ class PatchManager:
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
Gemma 4 has global (full_attention) layers with head_dim=512
which exceeds flash attention's supported size. This patch loads the model
with flash_attention_2 for the sliding window layers (head_dim=256), then
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
which fixes the corresponding mask construction inside
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
override is not enough — the forward still builds a 2D FA2-format mask
at the model level and the SDPA layers crash at long context lengths
with ``RuntimeError: The expanded size of the tensor ... must match``.
"""
if not self.cfg.gemma4_hybrid_attn_impl:
return
import copy
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
# Navigate to the module that has 'layers' - varies by model structure:
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
layers = None
config_source = None
for candidate in [model, getattr(model, "model", None)]:
if candidate is None:
continue
# Check direct layers
if hasattr(candidate, "layers"):
layers = candidate.layers
config_source = candidate
break
# Check language_model.layers (multimodal wrapper)
lang_model = getattr(candidate, "language_model", None)
if lang_model is not None and hasattr(lang_model, "layers"):
layers = lang_model.layers
config_source = lang_model
break
if layers is None:
LOG.warning(
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
)
return
config = getattr(config_source, "config", self.model_config)
layer_types = getattr(config, "layer_types", None)
if layer_types is None:
LOG.warning(
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
"This feature requires a model with mixed sliding/global attention layers."
)
return
patched_count = 0
for layer_idx, layer in enumerate(layers):
if layer_types[layer_idx] != "sliding_attention":
# Global / full_attention layer - use SDPA instead of FA2
attn_module = getattr(layer, "self_attn", None)
if attn_module is not None and hasattr(attn_module, "config"):
sdpa_config = copy.copy(attn_module.config)
sdpa_config._attn_implementation = "sdpa"
attn_module.config = sdpa_config
patched_count += 1
LOG.info(
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
"(remaining %d sliding layers use flash_attention_2)",
patched_count,
len(layers) - patched_count,
)
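The per-layer override relies on `copy.copy` giving each global layer its own config object, so flipping `_attn_implementation` cannot leak into sliding layers that still reference the shared original. A minimal sketch with toy classes (not the transformers API):

```python
import copy

class Cfg:
    def __init__(self, impl):
        self._attn_implementation = impl

shared = Cfg("flash_attention_2")
layer_types = ["sliding_attention", "full_attention", "sliding_attention"]
configs = [shared, shared, shared]  # all layers share one config object

for idx, kind in enumerate(layer_types):
    if kind != "sliding_attention":
        sdpa_config = copy.copy(configs[idx])  # new object, same fields
        sdpa_config._attn_implementation = "sdpa"
        configs[idx] = sdpa_config

assert configs[0]._attn_implementation == "flash_attention_2"
assert configs[1]._attn_implementation == "sdpa"
assert configs[2] is shared  # untouched layers still share the original
```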
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing:
@@ -324,6 +410,21 @@ class PatchManager:
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
# The fused attn path is now compatible with
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
# rotary (cos.shape[-1] < head_dim) and the fused forward
# mirrors the current ``Gemma4TextAttention.forward`` API
# for shared kv (read from / write to
# ``past_key_values.shared_layers``). See
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
# for the history.
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
patch_gemma4_fused_attn()
@staticmethod
def _fix_nemotron_h_conversion_mapping():
"""Remove the spurious embedding→embeddings WeightRenaming from the


@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
@@ -303,6 +295,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
{"additional_special_tokens": additional_special_tokens}
)
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
if is_main_process():
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")


@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
sharded_meta_param.placements,
src_data_rank=0,
)
# Clone the local shard to allow full_tensor to be freed.
if (
sharded_param._local_tensor.untyped_storage().size()
> sharded_param._local_tensor.nelement()
* sharded_param._local_tensor.element_size()
):
sharded_param = sharded_param.clone()
else:
# Non-sharded parameters
if _accelerator.is_main_process:


@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
try:
# flash-attn-4>=4.0.0b7
from flash_attn.cute import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
fa_utils._pad_input,
fa_utils._unpad_input,
)


@@ -0,0 +1,115 @@
"""Hybrid attention mask fix for Gemma 4.
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
``self_attn.config``, leaving sliding-window layers on FA2.
The per-layer config override alone is insufficient, however:
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
using the **model-level** config and passes the mapped mask to each
decoder layer. With FA2 still set at the model level, the ``full_attention``
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
The global layers then fail with::
RuntimeError: The expanded size of the tensor (S) must match the existing
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
sizes: [B, S]
...when the sequence length grows past roughly 7k tokens.
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
the original in ``masking_utils``. The wrapper forces
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
alone, so sliding-window layers continue to receive FA2-format masks.
The patch is idempotent. Install once per process, before any Gemma 4
forward pass runs.
"""
from __future__ import annotations
import copy
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
_PATCH_APPLIED = False
def patch_gemma4_hybrid_mask() -> bool:
    """Install the Gemma 4 hybrid-attention mask fix.

    Returns ``True`` if the patch was installed (or was already installed),
    ``False`` if the target module could not be imported (e.g. transformers
    version predates Gemma 4) — in which case nothing is done and the
    caller can continue unaffected.
    """
    global _PATCH_APPLIED
    if _PATCH_APPLIED:
        return True
    try:
        from transformers.models.gemma4 import modeling_gemma4
    except ImportError:
        LOG.debug(
            "gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
            "skipping. This is fine for non-Gemma4 training."
        )
        return False
    if not hasattr(modeling_gemma4, "create_causal_mask"):
        LOG.warning(
            "gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
            "binding, skipping. Transformers API may have changed."
        )
        return False

    original = modeling_gemma4.create_causal_mask

    def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
        """Wrapper that forces SDPA format for the full-attention mask.

        The global layers were patched to SDPA by
        ``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
        original ``create_causal_mask`` dispatches on
        ``config._attn_implementation``; we shadow that with a local
        override.
        """
        sdpa_config = copy.copy(config)
        sdpa_config._attn_implementation = "sdpa"
        return original(sdpa_config, *args, **kwargs)

    # Preserve the original reference on the wrapper for tests / teardown.
    hybrid_create_causal_mask._axolotl_original = original  # type: ignore[attr-defined]
    modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
    _PATCH_APPLIED = True
    LOG.info(
        "gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
        "force SDPA-format masks for full-attention layers"
    )
    return True


def unpatch_gemma4_hybrid_mask() -> None:
    """Restore the original ``create_causal_mask``. Useful for tests."""
    global _PATCH_APPLIED
    if not _PATCH_APPLIED:
        return
    try:
        from transformers.models.gemma4 import modeling_gemma4
    except ImportError:
        _PATCH_APPLIED = False
        return
    current = modeling_gemma4.create_causal_mask
    original = getattr(current, "_axolotl_original", None)
    if original is not None:
        modeling_gemma4.create_causal_mask = original
    _PATCH_APPLIED = False
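Patching the name in `modeling_gemma4`'s namespace rather than in `masking_utils` only affects call sites that resolved the function through the from-import binding, leaving the original definition intact for every other consumer. A self-contained illustration with synthetic modules (toy names, not the real transformers modules):

```python
import types

# simulate `masking_utils` defining create_causal_mask ...
masking_utils = types.ModuleType("masking_utils")

def _create_causal_mask(config):
    return config["_attn_implementation"]

masking_utils.create_causal_mask = _create_causal_mask

# ... and `modeling_gemma4` doing `from masking_utils import create_causal_mask`
modeling = types.ModuleType("modeling_gemma4")
modeling.create_causal_mask = masking_utils.create_causal_mask

original = modeling.create_causal_mask

def hybrid_create_causal_mask(config, *args, **kwargs):
    config = dict(config, _attn_implementation="sdpa")  # shallow override
    return original(config, *args, **kwargs)

hybrid_create_causal_mask._axolotl_original = original
modeling.create_causal_mask = hybrid_create_causal_mask

# only the importing module's binding sees the wrapper
assert modeling.create_causal_mask({"_attn_implementation": "flash_attention_2"}) == "sdpa"
assert masking_utils.create_causal_mask({"_attn_implementation": "flash_attention_2"}) == "flash_attention_2"
```

Stashing the original on `_axolotl_original` is what lets `unpatch_gemma4_hybrid_mask` restore the binding in tests.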


@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qk,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
@@ -111,6 +112,47 @@ QKV_PATCHES = [
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
# past_key_values.shared_layers, and v_norm added after k_norm.
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
]
@@ -483,18 +525,24 @@ def apply_lora_kernel_patches(
if cfg.lora_qkv_kernel:
# Query, key, value patching
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
proj_names = ["q_proj", "k_proj", "v_proj"]
layer_modules = [
getattr(self_attn, name)
for name in proj_names
if getattr(self_attn, name, None) is not None
]
has_v_proj = getattr(self_attn, "v_proj", None) is not None
proj_names = (
["q_proj", "k_proj", "v_proj"]
if has_v_proj
else ["q_proj", "k_proj"]
)
layer_modules = [getattr(self_attn, name) for name in proj_names]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
if has_v_proj:
self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, self_attn
)
else:
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters"

View File

@@ -0,0 +1,147 @@
"""
Gemma 4 fused attention monkeypatch.
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
kernels, eliminating the intermediate tensor allocations made by ``rotate_half``
and ``apply_rotary_pos_emb``.
Usage:
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
patch_gemma4_fused_attn()
"""
import logging
from typing import Callable
import torch
logger = logging.getLogger(__name__)
def _make_fused_forward(original_forward):
"""Create a patched forward that uses fused RMSNorm+RoPE kernels."""
from axolotl.kernels.gemma4_fused_rope import (
fused_rms_norm_noscale,
fused_rms_norm_rope,
)
def fused_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values=None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gemma4.modeling_gemma4 import (
eager_attention_forward,
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
eps = self.config.rms_norm_eps
cos, sin = position_embeddings
# ---- Projections ----
# Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
has_lora_qkv = hasattr(self, "apply_qkv")
if has_lora_qkv:
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
else:
query_states = self.q_proj(hidden_states).view(hidden_shape)
# ---- Q path: fused q_norm + RoPE ----
query_states = fused_rms_norm_rope(
query_states,
self.q_norm.weight,
cos,
sin,
eps=eps,
)
query_states = query_states.transpose(1, 2)
# ---- K/V path ----
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
if has_lora_qkv:
# apply_qkv already computed k/v projections
key_states = key_states.view(hidden_shape)
value_states = (
value_states.view(hidden_shape)
if self.v_proj is not None
else key_states
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = (
self.v_proj(hidden_states).view(hidden_shape)
if self.v_proj is not None
else key_states
)
# Fused k_norm + RoPE
key_states = fused_rms_norm_rope(
key_states,
self.k_norm.weight,
cos,
sin,
eps=eps,
)
key_states = key_states.transpose(1, 2)
# Fused v_norm (no scale, no RoPE)
value_states = fused_rms_norm_noscale(value_states, eps=eps)
value_states = value_states.transpose(1, 2)
if past_key_values is not None and not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
if self.store_full_length_kv:
shared_kv_states[self.layer_idx] = key_states, value_states
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=self.attention_dropout if self.training else 0.0,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
return fused_forward
def patch_gemma4_fused_attn():
"""
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
original_forward = Gemma4TextAttention.forward
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
logger.info(
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
)
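For reference, a minimal NumPy sketch of the unfused math that `fused_rms_norm_rope` is assumed to fuse (per the kernel description): RMSNorm followed by a partial rotary embedding where only the leading `n_rot = cos.shape[-1]` columns rotate and trailing columns pass through norm-only (equivalent to cos=1, sin=0). This is an illustration, not the Triton implementation.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: x / sqrt(mean(x^2) + eps), scaled by the learned weight
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def rotate_half(x):
    # Standard HF-style rotation: swap halves with a sign flip
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def rms_norm_rope_reference(x, weight, cos, sin, eps=1e-6):
    # Normalize first, then rotate only the leading n_rot dims; trailing
    # dims (hybrid-attention layers where n_rot < head_dim) pass through.
    x = rms_norm(x, weight, eps)
    n_rot = cos.shape[-1]
    rot = x[..., :n_rot] * cos + rotate_half(x[..., :n_rot]) * sin
    return np.concatenate([rot, x[..., n_rot:]], axis=-1)

out = rms_norm_rope_reference(
    np.ones((2, 8)), np.ones(8), cos=np.ones(4), sin=np.zeros(4)
)
```

With `sin=0` and `cos=1` the rotation is the identity, so the result equals the plain RMSNorm, which is exactly the pass-through behavior claimed for the trailing columns.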

View File

@@ -24,7 +24,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
# Some multimodal wrappers (e.g. Gemma 4) name the MLP class
# ``{prefix}TextMLP`` rather than ``{prefix}MLP`` because the
# language-side module is separated from the vision tower. Try
# both names before giving up.
mlp_cls = getattr(
module,
f"{model_cls_prefix}MLP",
None,
) or getattr(module, f"{model_cls_prefix}TextMLP")
if use_original_mlp:
mlp_forward = mlp_cls.forward

View File

@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._validate_eot_and_eos_tokens()
# Pre-cache EOT token IDs to avoid re-encoding on every call
self._eot_token_ids = set()
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) == 1:
self._eot_token_ids.add(token_ids[0])
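The pre-caching loop above only keeps EOT tokens that encode to exactly one id, since multi-token EOTs cannot be matched by a single-id scan. A hedged standalone sketch of that filter (`cache_single_token_ids` is a hypothetical name; `encode` stands in for `tokenizer.encode`):

```python
def cache_single_token_ids(tokens, encode):
    # encode: str -> list[int]; only single-token EOTs are matchable by id
    cached = set()
    for token in tokens:
        ids = encode(token)
        if len(ids) == 1:
            cached.add(ids[0])
    return cached

fake_vocab = {"<|eot|>": [7], "<end>": [3, 4]}
cached = cache_single_token_ids(["<|eot|>", "<end>"], fake_vocab.get)
```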
def _validate_eot_and_eos_tokens(self):
"""
- Validates that EOT tokens (or eos_token) are in the chat_template
@@ -471,6 +478,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")
reasoning_train_detail = turn.get("reasoning_training_detail")
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
@@ -479,8 +487,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
elif train_detail is not None or reasoning_train_detail is not None:
should_train = bool(train_detail) or bool(reasoning_train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train
@@ -500,15 +508,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
continue
thinking_key = self.prompter.template_thinking_key
has_reasoning = thinking_key and turn.get(thinking_key) is not None
has_any_detail = train_detail or reasoning_train_detail
# When train_detail is present and the turn has reasoning_content,
# use content_only=True so find_turn returns content-only boundaries
# (excluding reasoning_content + template separator tokens).
use_content_only = bool(has_any_detail and has_reasoning)
turn_start_idx, turn_end_idx = self.find_turn(
turns=turns, turn_idx=index, tools=tools
turns=turns,
turn_idx=index,
tools=tools,
content_only=use_content_only,
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
# Block multi-content for now
if not isinstance(content, str):
raise ValueError(
"`train_detail` is not supported when `content` is not a string."
@@ -526,7 +545,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
elif not reasoning_train_detail:
# No per-part detail on either field — train the whole span
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
@@ -534,6 +554,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
# Handle reasoning_content training_detail separately
if should_train and reasoning_train_detail and has_reasoning:
reasoning_text = turn[thinking_key]
if not isinstance(reasoning_text, str):
raise ValueError(
"`reasoning_training_detail` is not supported when reasoning_content is not a string."
)
reasoning_start, reasoning_end = self.find_turn(
turns=turns,
turn_idx=index,
tools=tools,
reasoning_only=True,
)
if reasoning_start != -1 and reasoning_end != -1:
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
reasoning_text, reasoning_train_detail
)
LOG.debug(f"Reasoning token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
input_ids
):
labels[reasoning_start + i] = input_ids[reasoning_start + i]
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle special tokens (EOT and EOS)
@@ -593,28 +639,31 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# Get token IDs for all EOT tokens
eot_token_ids = []
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) != 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
)
eot_token_ids.append(token_ids[0])  # single token guaranteed by the check above
# Search for any of the EOT token IDs
# Use pre-cached EOT token IDs (computed once in __init__)
for i in range(start_idx, len(input_ids)):
if input_ids[i] in eot_token_ids:
if input_ids[i] in self._eot_token_ids:
return i
return -1
def find_turn(
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
self,
turns: list[dict],
turn_idx: int,
tools: list[dict] | None = None,
content_only: bool = False,
reasoning_only: bool = False,
):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
content_only: If True and the turn has reasoning_content (template_thinking_key),
preserve reasoning_content in the dummy turn so the diff only captures the
content field boundaries. This is needed for correct training_detail alignment
when reasoning_content is present.
reasoning_only: If True, preserve content in the dummy turn and replace
reasoning_content with a dummy, so the diff only captures the
reasoning_content field boundaries.
"""
if turn_idx >= len(turns):
@@ -628,10 +677,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
):
return -1, -1
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
thinking_key = self.prompter.template_thinking_key
if reasoning_only:
# Keep content as-is, replace reasoning with dummy
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": turns[turn_idx].get("content", ""),
}
if thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = "[[dummy_reasoning]]"
else:
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
# When content_only is True, copy reasoning_content to the dummy turn so
# the diff only captures the content field (not reasoning + separator).
if content_only and thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = turns[turn_idx][thinking_key]
# Create conversation versions
turns_with_empty = turns[:turn_idx] + [empty_turn]
@@ -697,6 +762,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return start_idx, end_idx
@staticmethod
def _convert_content_parts(
content,
) -> tuple[str, list[dict] | None] | None:
"""Convert list content to concatenated string + optional training_detail.
When content is a list of dicts (content parts), each part can specify:
- ``text``, ``content``, or ``value``: the text string
- ``train`` (bool) or ``weight`` (0/1): per-part training flag
Returns ``(concatenated_text, training_details_or_None)`` if content was
a list, or ``None`` if content was not a list (no conversion needed).
.. note::
**Whitespace at part boundaries matters.** BPE tokenizers prepend
spaces to word tokens (e.g. ``" answer"`` is one token). Always
split BEFORE spaces::
GOOD: ["Let me think...", " The answer is 4."]
BAD: ["Let me think... ", "The answer is 4."]
Tokens that straddle a boundary are conservatively masked.
Newlines typically merge with preceding punctuation (``":\\n"`` is
one token), so keep newlines with the preceding part.
"""
if not isinstance(content, list):
return None
text_parts: list[str] = []
training_details: list[dict] = []
has_explicit_training = False
offset = 0
for part in content:
if isinstance(part, dict):
# Extract text (HF uses "text", also support "content"/"value")
text = (
part.get("text") or part.get("content") or part.get("value") or ""
)
text_parts.append(text)
# Check for per-part training flags
part_train = part.get("train")
part_weight = part.get("weight")
if part_train is not None or part_weight is not None:
has_explicit_training = True
train = (
part_train
if part_train is not None
else (part_weight not in (0, 0.0))
)
else:
train = True # default trainable, gated by turn-level should_train
if text:
training_details.append(
{
"begin_offset": offset,
"end_offset": offset + len(text) - 1,
"train": train,
}
)
offset += len(text)
# Warn about trailing whitespace at boundaries between parts with
# different training flags — this almost always causes token straddling
if has_explicit_training and len(training_details) > 1:
for i in range(len(training_details) - 1):
cur = training_details[i]
nxt = training_details[i + 1]
if cur["train"] != nxt["train"]:
boundary_text = text_parts[i]
if boundary_text and boundary_text[-1] in (" ", "\t"):
LOG.warning(
"Content part %d ends with whitespace at a train/mask boundary. "
"BPE tokenizers typically prepend spaces to word tokens, so "
"the space will merge with the next part's first word and the "
"resulting token will be MASKED (not trained). Move the "
"whitespace to the start of the next content part instead. "
"Part text: %r",
i,
boundary_text[-20:],
)
concatenated = "".join(text_parts)
details = training_details if has_explicit_training else None
return concatenated, details
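The offset bookkeeping above can be illustrated with a stripped-down sketch (a hypothetical `convert_parts`, keeping the same inclusive `end_offset` convention), showing why the whitespace-before-words rule matters: the leading space of `" The answer is 4."` stays with the trained part, so its offsets start exactly where the masked part ends.

```python
def convert_parts(parts):
    # Concatenate part texts and record inclusive [begin, end] character
    # offsets plus the per-part train flag, mirroring _convert_content_parts.
    text, details, offset = "", [], 0
    for part in parts:
        t = part.get("text", "")
        if t:
            details.append({
                "begin_offset": offset,
                "end_offset": offset + len(t) - 1,
                "train": part.get("train", True),
            })
        text += t
        offset += len(t)
    return text, details

text, details = convert_parts([
    {"text": "Let me think...", "train": False},
    {"text": " The answer is 4.", "train": True},  # space leads the trained part
])
```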
def get_conversation_thread(self, prompt):
turns = []
@@ -723,6 +876,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if training_detail is not None:
turn["training_detail"] = training_detail
# Convert list content/reasoning_content to string + auto-generated
# training_detail. See _convert_content_parts for whitespace guidance.
content_result = self._convert_content_parts(turn.get("content"))
if content_result is not None:
turn["content"] = content_result[0]
if content_result[1] is not None:
turn["training_detail"] = content_result[1]
# Also convert reasoning_content (template_thinking_key) if it's a list
thinking_key = self.prompter.template_thinking_key
if thinking_key and thinking_key in turn:
reasoning_result = self._convert_content_parts(turn[thinking_key])
if reasoning_result is not None:
turn[thinking_key] = reasoning_result[0]
if reasoning_result[1] is not None:
turn["reasoning_training_detail"] = reasoning_result[1]
turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":

View File

@@ -320,6 +320,15 @@ def main(script_args: ScriptArguments):
# --- Active LoRA state (shared across endpoints via closure) ---
active_lora: dict = {"request": None}
# Serializes access to the worker pipe. The underlying
# multiprocessing.Connection is a single full-duplex stream shared
# across all HTTP handlers; concurrent requests interleave bytes on
# the wire and corrupt the pickle framing (seen as
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
# the round-trip so that only one call is in flight per pipe at a time.
worker_pipe_lock = asyncio.Lock()
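The send/recv round-trip the comment describes can be sketched as follows. `EchoConn` is a hypothetical stand-in for the `multiprocessing.Connection`; the point is that the lock spans both the send and the recv, so a second request cannot write to the pipe mid-round-trip.

```python
import asyncio

async def worker_call(conn, lock, payload):
    # Hold the lock across send+recv so concurrent requests cannot
    # interleave pickle frames on the shared worker pipe.
    async with lock:
        conn.send(payload)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, conn.recv)

class EchoConn:
    # Stand-in for multiprocessing.Connection (illustrative only)
    def send(self, obj):
        self._last = obj
    def recv(self):
        return self._last

async def main():
    lock = asyncio.Lock()
    conn = EchoConn()
    return await asyncio.gather(
        worker_call(conn, lock, {"type": "call", "method": "generate"}),
        worker_call(conn, lock, {"type": "call", "method": "sleep"}),
    )

results = asyncio.run(main())
```

Without the lock, the second `send` could land between the first call's `send` and `recv`, which on a real pickle-framed pipe corrupts the stream.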
# ------------------------------------------------------------------
# LoRA-specific endpoints
# ------------------------------------------------------------------
@@ -631,6 +640,150 @@ def main(script_args: ScriptArguments):
},
}
@app.post("/v1/completions")
async def openai_completions(request_body: dict):
"""OpenAI-compatible text-completions endpoint.
Accepts either a string ``prompt`` or a list-of-int
``prompt_token_ids`` (as the text-completions spec allows). Routes
to the internal vLLM generate method with the active LoRA adapter
and returns an OpenAI /v1/completions-shaped response including
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
``generation_log_probs`` for NeMo Gym agents that need raw
tokens + logprobs.
"""
import uuid
prompt_raw = request_body.get("prompt")
temperature = request_body.get("temperature", 1.0)
max_tokens = request_body.get("max_tokens", 512)
top_p = request_body.get("top_p", 1.0)
n = request_body.get("n", 1)
logprobs = request_body.get("logprobs") or 0
stop_token_ids = request_body.get("stop_token_ids") or None
# Accept either a string prompt or a list[int] of token ids (lists
# are passed through to vLLM as prompt_token_ids). Also accept
# [[int, int, ...]] nesting for the rare case callers pass a
# single-prompt batch.
if (
isinstance(prompt_raw, list)
and prompt_raw
and isinstance(prompt_raw[0], list)
):
prompt_raw = prompt_raw[0]
prompt_dict: dict[str, Any] = {}
if isinstance(prompt_raw, list):
prompt_dict = {"prompt_token_ids": prompt_raw}
elif isinstance(prompt_raw, str):
prompt_dict = {"prompt": prompt_raw}
else:
return {
"error": {
"message": ("prompt must be a string or a list of token ids"),
"type": "invalid_request",
}
}
generation_kwargs: dict[str, Any] = {
"n": n,
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"logprobs": logprobs,
}
if stop_token_ids:
generation_kwargs["stop_token_ids"] = stop_token_ids
sampling_params = SamplingParams(
**{k: v for k, v in generation_kwargs.items() if v is not None}
)
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
# Hold the pipe lock across send+recv — concurrent requests would
# otherwise interleave pickle frames on the worker connection.
async with worker_pipe_lock:
for conn, chunk in zip(connections, chunked, strict=True):
if not chunk:
chunk = [{"prompt": "<placeholder>"}]
kwargs = {
"prompts": chunk,
"sampling_params": sampling_params,
"lora_request": active_lora["request"],
}
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
loop = asyncio.get_running_loop()
all_outputs = await asyncio.gather(
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
)
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
for o in all_outputs:
if isinstance(o, dict) and "error" in o:
raise RuntimeError(f"vLLM worker error: {o['error']}")
all_outputs = list(chain.from_iterable(all_outputs))
if not all_outputs:
return {"choices": [], "model": script_args.model}
choices = []
for i, output in enumerate(all_outputs):
for j, out in enumerate(output.outputs):
text = out.text
# OpenAI-style `logprobs` block for text-completions:
# { "tokens": [...], "token_logprobs": [...] }
lp_block = None
if out.logprobs:
tokens_str: list[str] = []
token_lps: list[float] = []
for step in out.logprobs:
chosen = next(iter(step.values()))
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
token_lps.append(float(chosen.logprob))
lp_block = {
"tokens": tokens_str,
"token_logprobs": token_lps,
}
choice = {
"index": i * n + j,
"text": text,
"finish_reason": "stop"
if out.finish_reason == "stop"
else "length",
"logprobs": lp_block,
# NeMo-Gym / retrace agent extras — preserved on the
# choice so callers with raw-token pipelines don't
# have to re-tokenize.
"prompt_token_ids": output.prompt_token_ids,
"generation_token_ids": list(out.token_ids),
"generation_log_probs": (
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
if out.logprobs
else []
),
}
choices.append(choice)
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
completion_tokens = sum(
len(out.token_ids) for o in all_outputs for out in o.outputs
)
return {
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
"object": "text_completion",
"model": script_args.model,
"choices": choices,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
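The prompt normalization at the top of the endpoint reduces to a small pure function, sketched here (`normalize_prompt` is a hypothetical name for illustration):

```python
def normalize_prompt(prompt_raw):
    # A string becomes {"prompt": ...}, a list of ints becomes
    # {"prompt_token_ids": ...}, and a single-prompt batch [[int, ...]]
    # is unwrapped first.
    if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list):
        prompt_raw = prompt_raw[0]
    if isinstance(prompt_raw, list):
        return {"prompt_token_ids": prompt_raw}
    if isinstance(prompt_raw, str):
        return {"prompt": prompt_raw}
    raise ValueError("prompt must be a string or a list of token ids")
```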
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
@app.post("/init_communicator/")

View File

@@ -160,29 +160,16 @@ class TelemetryManager:
if not is_main_process():
return False
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
def is_truthy_env(var_name: str) -> bool:
value = os.getenv(var_name)
if value is None:
return False
return value.strip().lower() in ("1", "true")
# Default to enabled (opt-out model)
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
"0",
"1",
"false",
"true",
):
return True
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
return enabled
# Telemetry is enabled by default unless either opt-out var is set
return not (
is_truthy_env("AXOLOTL_DO_NOT_TRACK") or is_truthy_env("DO_NOT_TRACK")
)
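The opt-out semantics above reduce to: telemetry stays on unless either variable is explicitly truthy. A self-contained check of that behavior (whitespace and case are tolerated):

```python
import os

def is_truthy_env(var_name: str) -> bool:
    # Only "1"/"true" (case-insensitive, whitespace-trimmed) count as opt-out
    value = os.getenv(var_name)
    if value is None:
        return False
    return value.strip().lower() in ("1", "true")

os.environ.pop("DO_NOT_TRACK", None)
os.environ["AXOLOTL_DO_NOT_TRACK"] = " True "
telemetry_enabled = not (
    is_truthy_env("AXOLOTL_DO_NOT_TRACK") or is_truthy_env("DO_NOT_TRACK")
)
```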
def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist"""

View File

@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
):
model.enable_input_require_grads()
# Freeze multimodal modules for text-only training of multimodal models
if cfg.freeze_mm_modules:
freeze_mm_modules(model)
return model, tokenizer, peft_config, processor
@@ -225,6 +229,28 @@ def execute_training(
PLUGIN_MANAGER.post_train(cfg, trainer.model)
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
"""Rename model*.safetensors files to adapter_model* in place.
Also rewrites the index JSON weight_map if sharded output was produced.
"""
for file in sorted(merged_dir.iterdir()):
if file.name.startswith("model") and ".safetensors" in file.name:
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
index = merged_dir / "adapter_model.safetensors.index.json"
if index.exists():
data = json.loads(index.read_text(encoding="utf-8"))
if "weight_map" in data:
data["weight_map"] = {
k: v.replace("model", "adapter_model", 1)
for k, v in data["weight_map"].items()
}
index.write_text(
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
)
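The renaming relies on `str.replace(..., 1)` touching only the first `"model"` occurrence, so shard suffixes like `-00001-of-00002` survive intact. A sketch of both the file rename and the index rewrite (helper names are hypothetical):

```python
def rename_shard(name):
    # Only rename model*.safetensors files; first "model" occurrence only
    if name.startswith("model") and ".safetensors" in name:
        return name.replace("model", "adapter_model", 1)
    return name

def rewrite_weight_map(weight_map):
    # The index's weight_map must point at the renamed shard files
    return {k: v.replace("model", "adapter_model", 1) for k, v in weight_map.items()}
```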
def save_trained_model(
cfg: DictDefault,
trainer: Any,
@@ -294,12 +320,17 @@ def save_trained_model(
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
# FSDP checkpoints for PEFT only contain adapter weights;
# rename model* → adapter_model* so it loads correctly.
is_peft = cfg.adapter and not cfg.relora
if is_peft:
_rename_fsdp_merged_to_adapter(Path(merged_path))
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
dest = Path(cfg.output_dir) / merged_file.name
if dest.exists():
dest.unlink()
shutil.move(str(merged_file), dest)
shutil.rmtree(merged_path)
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
if trainer.accelerator.is_main_process:

View File

@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
return control
class SkipEvalOnResumeCallback(TrainerCallback):
"""Skip the redundant evaluation that fires when resuming from a checkpoint
whose step aligns with ``eval_steps``.
When HuggingFace Trainer resumes, it restores ``global_step`` from the
checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
step. Because the evaluation was already performed during the original
run, repeating it wastes time and pollutes metric logs.
This callback records the ``global_step`` at the start of training (i.e.
the checkpoint step when resuming, or 0 for a fresh run) and suppresses
evaluation requests at or before that step.
"""
def __init__(self):
super().__init__()
self._resume_step: int | None = None
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
# ``global_step`` is already restored from the checkpoint at this
# point. For a fresh run it will be 0, so the guard below becomes a
# no-op.
self._resume_step = state.global_step
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
if (
self._resume_step
and state.global_step <= self._resume_step
and control.should_evaluate
):
LOG.info(
"Skipping evaluation at step %d (already completed before resume)",
state.global_step,
)
control.should_evaluate = False
return control
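The guard in `on_step_end` distills to a single predicate, sketched here as a hypothetical helper: skip only when resuming (`resume_step` nonzero), at or before the restored step, and only if an eval was actually requested.

```python
def should_skip_eval(resume_step, global_step, should_evaluate):
    # Fresh runs have resume_step == 0, so the guard is a no-op there
    return bool(resume_step) and global_step <= resume_step and should_evaluate
```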
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -1,7 +1,19 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- if messages[0].content is string %}
{{- messages[0].content + '\n\n' }}
{%- else %}
{%- for part in messages[0].content %}
{%- if part is mapping %}
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
{%- if system_text %}{{- system_text }}{%- endif %}
{%- elif part is string %}
{{- part }}
{%- endif %}
{%- endfor %}
{{- '\n\n' }}
{%- endif %}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
@@ -11,7 +23,20 @@
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- if messages[0].content is string %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>system\n' }}
{%- for part in messages[0].content %}
{%- if part is mapping %}
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
{%- if system_text %}{{- system_text }}{%- endif %}
{%- elif part is string %}
{{- part }}
{%- endif %}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}

View File

@@ -268,6 +268,37 @@ def normalize_config(cfg):
):
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
# Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
# "marked ready twice" errors with reentrant checkpointing) and
# ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
# receive gradients on every step).
if cfg.model_config_type == "gemma4":
if cfg.gradient_checkpointing:
if cfg.gradient_checkpointing_kwargs is None:
cfg.gradient_checkpointing_kwargs = {}
if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
LOG.warning(
"Gemma4 requires use_reentrant=False for gradient checkpointing "
"in distributed training. Setting use_reentrant=False."
)
cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
if cfg.ddp and cfg.ddp_find_unused_parameters is None:
if cfg.activation_offloading is True:
# activation_offloading uses checkpoint wrappers that conflict
# with find_unused_parameters (causes "marked ready twice").
# Use freeze_mm_modules instead to eliminate unused params.
LOG.info(
"Gemma4 + DDP + activation_offloading: skipping "
"ddp_find_unused_parameters (use freeze_mm_modules to "
"handle unused vision/audio params)."
)
else:
LOG.warning(
"Gemma4 requires ddp_find_unused_parameters=True for DDP. "
"Auto-enabling."
)
cfg.ddp_find_unused_parameters = True
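The normalization above can be sketched on a plain dict config (a simplified stand-in for the real `DictDefault` cfg, omitting the logging):

```python
def normalize_gemma4(cfg: dict) -> dict:
    # Force use_reentrant=False under gradient checkpointing and auto-enable
    # ddp_find_unused_parameters for DDP (unless activation offloading is on).
    if cfg.get("model_config_type") == "gemma4":
        if cfg.get("gradient_checkpointing"):
            kwargs = cfg.setdefault("gradient_checkpointing_kwargs", {})
            if kwargs.get("use_reentrant") is not False:
                kwargs["use_reentrant"] = False
        if cfg.get("ddp") and cfg.get("ddp_find_unused_parameters") is None:
            if not cfg.get("activation_offloading"):
                cfg["ddp_find_unused_parameters"] = True
    return cfg

cfg = normalize_gemma4(
    {"model_config_type": "gemma4", "gradient_checkpointing": True, "ddp": True}
)
```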
log_gpu_memory_usage(LOG, "baseline", cfg.device)

View File

@@ -180,6 +180,119 @@ def _drop_long_sequences(
raise ValueError("Unknown RL type")
def _raise_on_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
"""Check sequence length and raise ValueError if exceeded.
Used as a filter function for ``excess_length_strategy: raise``.
Args:
sample: Dataset sample to check.
rl: Reinforcement learning type.
tokenizer: Tokenizer for length calculation.
sequence_len: Maximum allowed sequence length.
Returns:
Always True (raises before returning False).
Raises:
ValueError: If any sample exceeds the configured sequence length.
"""
is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
if not is_valid:
raise ValueError(
f"Sample exceeds configured sequence_len ({sequence_len}). "
"Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
"to handle long sequences automatically."
)
return True
def _truncate_long_sequences_rl(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> dict[str, Any]:
"""Truncate RL samples that exceed maximum sequence length.
For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
responses to fit within ``sequence_len`` when combined with the prompt.
For KTO, truncates the completion similarly.
GRPO/GDPO/EBFT samples are returned unchanged.
Samples where the prompt alone exceeds ``sequence_len`` cannot be
meaningfully truncated and are returned unchanged. The caller should
follow up with a drop filter to remove them.
Args:
sample: Dataset sample to potentially truncate.
rl: Reinforcement learning type.
tokenizer: Tokenizer for encoding/decoding.
sequence_len: Maximum allowed sequence length.
Returns:
The sample with text fields truncated to fit within sequence_len.
"""
# Fast path: if sample already fits, return unchanged (avoids decode overhead)
if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
return sample
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
raise ValueError(
"Prompt, chosen and rejected keys are required for DPO/IPO/ORPO/SIMPO datasets"
)
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
"input_ids"
]
max_response_len = sequence_len - len(prompt_ids)
if max_response_len <= 0:
# Prompt alone exceeds limit; cannot meaningfully truncate.
# Return it unchanged; the follow-up drop filter will remove it.
return sample
updates: dict[str, Any] = {}
if len(chosen_ids) > max_response_len:
updates["chosen"] = tokenizer.decode(
chosen_ids[:max_response_len], skip_special_tokens=False
)
if len(rejected_ids) > max_response_len:
updates["rejected"] = tokenizer.decode(
rejected_ids[:max_response_len], skip_special_tokens=False
)
if updates:
sample = {**sample, **updates}
elif rl is RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
"input_ids"
]
max_completion_len = sequence_len - len(prompt_ids)
if max_completion_len <= 0:
return sample
if len(completion_ids) > max_completion_len:
sample = {
**sample,
"completion": tokenizer.decode(
completion_ids[:max_completion_len], skip_special_tokens=False
),
}
# GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
return sample
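The preference-pair truncation above can be sketched with a toy whitespace tokenizer. The `tokenize`/`detokenize` callables below are hypothetical stand-ins for a real HF tokenizer's `__call__` and `decode`; the control flow mirrors the DPO branch (fit check, prompt-too-long early return, per-field trimming):

```python
def truncate_dpo_sample(sample: dict, tokenize, detokenize, sequence_len: int) -> dict:
    """Toy version of the DPO branch: trim chosen/rejected so prompt+response fits."""
    prompt_ids = tokenize(sample["prompt"])
    max_response_len = sequence_len - len(prompt_ids)
    if max_response_len <= 0:
        # Prompt alone exceeds the limit; a follow-up drop filter removes it.
        return sample
    updates = {}
    for key in ("chosen", "rejected"):
        ids = tokenize(sample[key])
        if len(ids) > max_response_len:
            updates[key] = detokenize(ids[:max_response_len])
    return {**sample, **updates} if updates else sample

# Whitespace "tokenizer": one token per word.
sample = {"prompt": "q1 q2", "chosen": "a b c d e f", "rejected": "x y"}
out = truncate_dpo_sample(sample, str.split, " ".join, sequence_len=5)
print(out["chosen"])    # a b c   (trimmed to the 3 remaining slots)
print(out["rejected"])  # x y     (already fits, unchanged)
```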
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
split_datasets[i] = dataset
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
if excess_length_strategy == "truncate":
truncate_fn = partial(
_truncate_long_sequences_rl,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].map(
truncate_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# Drop samples that could not be truncated (e.g. prompt
# alone exceeds sequence_len)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Un-truncatable Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} samples from dataset index {i} "
f"that could not be truncated to fit sequence_len "
f"(prompt alone exceeds limit)"
)
elif excess_length_strategy == "raise":
raise_fn = partial(
_raise_on_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
raise_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Checking Sequence Lengths",
)
else: # "drop" (default)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)
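The three `excess_length_strategy` branches reduce to a simple dispatch. A minimal sketch over plain integer token lengths (a hypothetical helper, not the dataset-level code above, which operates on `datasets.Dataset` via `filter`/`map`):

```python
def apply_excess_length_strategy(lengths, sequence_len, strategy=None):
    """Toy dispatch mirroring drop / truncate / raise over a list of sample lengths."""
    strategy = (strategy or "drop").lower()
    if strategy == "raise":
        for n in lengths:
            if n > sequence_len:
                raise ValueError(
                    f"Sample of length {n} exceeds configured sequence_len ({sequence_len})"
                )
        return list(lengths)
    if strategy == "truncate":
        # Real truncation also drops samples whose prompt alone exceeds the limit.
        return [min(n, sequence_len) for n in lengths]
    # "drop" (default): filter out anything over the limit.
    return [n for n in lengths if n <= sequence_len]

print(apply_excess_length_strategy([3, 9, 5], 6, "drop"))      # [3, 5]
print(apply_excess_length_strategy([3, 9, 5], 6, "truncate"))  # [3, 6, 5]
```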

View File

@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Top-level module name prefixes that belong to vision/audio/multimodal encoders
rather than the language backbone. These are matched against each dot-separated
component of a ``named_parameter`` path (e.g. "model.vision_tower.weight" matches "vision_tower").
_MM_MODULE_PREFIXES = (
"vision_tower",
"vision_model",
"vision_encoder",
"embed_vision",
"multi_modal_projector",
"visual",
"audio_tower",
"audio_model",
"embed_audio",
)
def freeze_mm_modules(model):
"""Freeze all vision/audio/multimodal-projector parameters.
Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
for any parameter whose dotted name has a known vision/audio module prefix as a path component.
This is useful when fine-tuning only the language backbone of a multimodal
model and avoids the need for ``ddp_find_unused_parameters=True``.
"""
frozen_count = 0
for name, param in model.named_parameters():
# Check if any path component matches a vision/audio prefix
parts = name.split(".")
if any(part in _MM_MODULE_PREFIXES for part in parts):
if param.requires_grad:
param.requires_grad = False
frozen_count += 1
if is_main_process():
LOG.debug(f"freeze_mm_modules: froze {name}")
if is_main_process():
LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
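The matching rule is whole-component, not substring: a name is frozen only if one of its dot-separated parts equals a known prefix. A stdlib-only sketch of just the predicate (abridged prefix list; no torch dependency):

```python
_MM_MODULE_PREFIXES = ("vision_tower", "audio_tower", "multi_modal_projector")  # abridged

def is_mm_param(name: str) -> bool:
    """True if any dot-separated component of the parameter path is a known MM prefix."""
    return any(part in _MM_MODULE_PREFIXES for part in name.split("."))

print(is_mm_param("model.vision_tower.blocks.0.attn.qkv.weight"))  # True
print(is_mm_param("model.layers.0.self_attn.q_proj.weight"))       # False
print(is_mm_param("model.vision_tower_adapter.weight"))            # False: no substring match
```

The last case is the reason for splitting on dots: a hypothetical `vision_tower_adapter` module that should stay trainable is not caught by accident.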
def freeze_layers_except(model, regex_patterns):
"""

View File

@@ -578,6 +578,17 @@ class AxolotlInputConfig(
},
)
freeze_mm_modules: bool | None = Field(
default=None,
json_schema_extra={
"description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
"text-only training of multimodal models. When True, parameters belonging to "
"vision towers, audio towers, multimodal projectors, and similar non-language "
"modules are frozen (requires_grad=False). This allows DDP training without "
"ddp_find_unused_parameters=True."
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
@@ -766,6 +777,15 @@ class AxolotlInputConfig(
},
)
gemma4_hybrid_attn_impl: bool | None = Field(
default=None,
json_schema_extra={
"description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
"and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
"exceeds flash attention's supported size."
},
)
experts_implementation: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -87,9 +87,11 @@ class ModelInputConfig(BaseModel):
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
)
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
model_quantization_config: Literal["Mxfp4Config", "FineGrainedFP8Config"] | None = (
Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
)
)
model_quantization_config_kwargs: dict[str, Any] | None = Field(
default=None,

View File

@@ -770,6 +770,88 @@ class RLValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_grpo_batch_size_divisibility(cls, data):
"""Surface GRPO batch-shape mismatches at config-parse time.
TRL's GRPOTrainer requires that the per-step generation batch size be
evenly divisible by ``num_generations`` so that every prompt can be
replicated exactly ``num_generations`` times. The runtime check inside
``GRPOTrainer.__init__`` only fires after the model has been loaded —
too late and too cryptic for the user. We replicate the check here so
the failure is immediate and actionable.
Also enforces:
- ``num_generations >= 2`` (group-relative advantage needs variance)
- ``effective_gbs >= num_generations * world_size`` when capabilities
indicate multiple ranks (each rank needs at least one full group)
"""
if data.get("rl") != "grpo":
return data
trl_cfg = data.get("trl") or {}
num_gen = trl_cfg.get("num_generations")
if num_gen is None:
# TRL's own default is 8 — but if the user didn't set it, we
# don't have enough info to validate anything. Let TRL's own
# init handle the default-vs-batch interaction.
return data
if num_gen < 2:
raise ValueError(
f"GRPO requires `trl.num_generations >= 2` (got {num_gen}). "
"With num_generations=1, every group has zero advantage and "
"the policy never updates."
)
explicit_gbs = trl_cfg.get("generation_batch_size")
if explicit_gbs is not None:
effective_gbs = int(explicit_gbs)
gbs_source = "trl.generation_batch_size"
else:
mb = data.get("micro_batch_size") or 1
ga = data.get("gradient_accumulation_steps") or 1
effective_gbs = int(mb) * int(ga)
gbs_source = f"micro_batch_size ({mb}) * gradient_accumulation_steps ({ga})"
if effective_gbs % num_gen != 0:
# Suggest the smallest GA bump that fixes it for the common case
# where the user hasn't set generation_batch_size explicitly.
hint = ""
if explicit_gbs is None:
from math import gcd
mb_val = int(data.get("micro_batch_size") or 1)
# smallest GA such that mb*GA is a multiple of num_gen
lcm = num_gen * mb_val // gcd(num_gen, mb_val)
suggested_ga = lcm // mb_val
hint = (
f" Smallest fix: set `gradient_accumulation_steps: "
f"{suggested_ga}` (so micro_batch_size * GA = "
f"{mb_val * suggested_ga} is a multiple of {num_gen})."
)
raise ValueError(
f"GRPO: generation batch size must be divisible by "
f"`trl.num_generations`. Got effective_gbs={effective_gbs} "
f"(from {gbs_source}) and num_generations={num_gen}.{hint}"
)
# Multi-rank check: each rank must receive at least one full group
# per step. Without `capabilities` populated yet (mode='before'), we
# fall back to user-set distributed fields.
world_size = (
(data.get("capabilities") or {}).get("n_gpu") or data.get("world_size") or 1
)
if world_size and world_size > 1 and effective_gbs < num_gen * world_size:
raise ValueError(
f"GRPO with world_size={world_size} requires effective_gbs "
f">= num_generations * world_size = {num_gen * world_size}, "
f"got {effective_gbs}. Increase gradient_accumulation_steps "
f"or micro_batch_size."
)
return data
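The "smallest fix" hint in the validator is just an lcm computation: the smallest `gradient_accumulation_steps` such that `micro_batch_size * GA` is a multiple of `num_generations`. Isolated as a sketch:

```python
from math import gcd

def smallest_ga_fix(micro_batch_size: int, num_generations: int) -> int:
    """Smallest GA such that micro_batch_size * GA is a multiple of num_generations."""
    lcm = num_generations * micro_batch_size // gcd(num_generations, micro_batch_size)
    return lcm // micro_batch_size

print(smallest_ga_fix(3, 8))  # 8  -> effective_gbs = 3 * 8 = 24, divisible by 8
print(smallest_ga_fix(4, 8))  # 2  -> effective_gbs = 4 * 2 = 8
```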
class OptimizationValidationMixin:
"""Validation methods related to optimization and performance."""

View File

@@ -216,5 +216,197 @@ class TestValidateQuantPatchRestore(unittest.TestCase):
self.assertIs(_trainer_module.validate_quantization_for_training, original)
class TestVllmLoraSyncPatch(unittest.TestCase):
"""The ``_generate_single_turn`` patch wires sync_weights to the right place.
These tests exercise the patch-installation branch in isolation. They build
a stub trainer with just enough attributes to look like
``AsyncGRPOTrainer`` for the duration of the relevant code path.
Background — there are two correct behaviors and we historically had a bug
where both modes used the same one:
- Async prefetch ON: the BG generation thread can't safely call
sync_weights mid-rollout. We no-op the stock hook and drive sync from
the main thread via ``_maybe_sync_vllm_weights``.
- Async prefetch OFF: TRL's stock ``_generate_single_turn`` already
calls ``sync_weights`` once per step boundary on the main thread. We
wire that hook directly to ``_sync_lora_adapter`` because
``_maybe_sync_vllm_weights`` short-circuits when async is off.
Before the fix, both modes installed ``lambda: None``, so sync mode never
pushed any LoRA adapter to vLLM and the trainer was a no-op.
"""
@staticmethod
def _make_stub_trainer(*, vllm_lora_sync, async_prefetch):
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
class FakeArgs:
pass
args = FakeArgs()
args.vllm_lora_sync = vllm_lora_sync
args.async_prefetch = async_prefetch
class FakeVllmGen:
sync_weights = staticmethod(lambda: None)
model = MagicMock()
# Use object.__new__ so we don't run __init__ (which needs a real
# model, dataset, etc.). We only need the `_generate_single_turn`
# method's patch branch to run, so we set up the minimum state.
trainer = object.__new__(AsyncGRPOTrainer)
trainer.args = args
trainer.use_vllm = True
trainer.vllm_generation = FakeVllmGen()
trainer._patched_sync_weights = False
# Spy on _sync_lora_adapter so we can assert it's the function the
# hook delegates to in sync mode.
trainer._sync_lora_adapter = MagicMock(name="_sync_lora_adapter_spy")
trainer._sync_peft_weights_no_merge = MagicMock(
name="_sync_peft_weights_no_merge_spy"
)
return trainer
@staticmethod
def _run_patch_branch(trainer):
"""Execute just the sync_weights-patching branch in isolation.
We can't easily call the real ``_generate_single_turn`` because it
does a full vLLM generate. Instead we copy the exact branch out of
the source so the test verifies the same logic the trainer runs.
"""
if not getattr(trainer, "_patched_sync_weights", False):
if trainer.use_vllm and hasattr(trainer, "vllm_generation"):
if getattr(trainer.args, "vllm_lora_sync", False):
if getattr(trainer.args, "async_prefetch", False):
trainer.vllm_generation.sync_weights = lambda: None
else:
sync_helper = trainer._sync_lora_adapter
def _lora_filesystem_sync():
sync_helper()
trainer.vllm_generation.sync_weights = _lora_filesystem_sync
trainer._patched_sync_weights = True
def test_sync_mode_with_lora_sync_wires_to_sync_lora_adapter(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
assert trainer._patched_sync_weights is True
# Trigger the patched hook — it must call _sync_lora_adapter.
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_called_once()
def test_async_mode_with_lora_sync_installs_noop_hook(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=True)
self._run_patch_branch(trainer)
assert trainer._patched_sync_weights is True
# Hook must be a no-op so BG-thread generation doesn't fight the
# main-thread optimizer step over the model weights.
trainer.vllm_generation.sync_weights()
trainer._sync_lora_adapter.assert_not_called()
def test_sync_mode_with_lora_sync_does_not_call_during_install(self):
"""Installing the patch should not pre-emptively sync."""
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
# _sync_lora_adapter should only be called when the patched hook
# itself is invoked (e.g., from TRL's _generate_single_turn).
trainer._sync_lora_adapter.assert_not_called()
def test_patch_is_idempotent(self):
trainer = self._make_stub_trainer(vllm_lora_sync=True, async_prefetch=False)
self._run_patch_branch(trainer)
first_hook = trainer.vllm_generation.sync_weights
# Second call must not re-patch (otherwise we'd lose the original).
self._run_patch_branch(trainer)
assert trainer.vllm_generation.sync_weights is first_hook
class TestMaybeSyncVllmWeightsIntervalDefault(unittest.TestCase):
"""``_maybe_sync_vllm_weights`` must not crash when interval is unset.
Before the fix, ``step % self.args.vllm_sync_interval`` would TypeError
on the very first call when ``vllm_sync_interval`` was ``None`` (which
is the default for any config that doesn't explicitly set it). We now
fall back to interval=1 so unset means "sync every step", matching the
behavior of TRL's own ``_generate_single_turn``.
"""
@staticmethod
def _make_stub_trainer(interval, async_prefetch):
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
class FakeArgs:
pass
args = FakeArgs()
args.async_prefetch = async_prefetch
args.vllm_sync_interval = interval
args.vllm_lora_sync = True
class FakeState:
global_step = 1
trainer = object.__new__(AsyncGRPOTrainer)
trainer.args = args
trainer.use_vllm = True
trainer.state = FakeState()
trainer._last_synced_step = 0
trainer._sync_lora_adapter = MagicMock(name="sync_spy")
return trainer
def test_interval_none_in_async_mode_does_not_crash(self):
trainer = self._make_stub_trainer(interval=None, async_prefetch=True)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
# Should not raise TypeError — defaults to every-step sync
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
def test_sync_mode_drives_sync(self):
"""Sync mode must fire ``_sync_lora_adapter`` from ``_maybe_sync_vllm_weights``.
The previous behavior (early return when ``not async_prefetch``)
assumed TRL's stock ``_generate_single_turn`` would handle sync.
That's true for vanilla GRPO but FALSE for NeMo Gym multi-turn
where the data producer bypasses ``_generate_single_turn``
entirely. Without this trigger no sync ever happens and the
trainer becomes a no-op.
"""
trainer = self._make_stub_trainer(interval=1, async_prefetch=False)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
def test_async_mode_with_explicit_interval_respects_modulo(self):
trainer = self._make_stub_trainer(interval=4, async_prefetch=True)
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOTrainer,
)
# global_step=1, interval=4 → 1 % 4 != 0 → no sync
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_not_called()
# global_step=4 → 4 % 4 == 0 → sync
trainer.state.global_step = 4
AsyncGRPOTrainer._maybe_sync_vllm_weights(trainer)
trainer._sync_lora_adapter.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -54,25 +54,7 @@ except (ImportError, ModuleNotFoundError):
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
smoe_A = torch.zeros(
rank * num_experts,
K_inter,
device=peft_A.device,
dtype=peft_A.dtype,
)
smoe_B = torch.zeros(
N_hidden,
rank * num_experts,
device=peft_A.device,
dtype=peft_A.dtype,
)
for e in range(num_experts):
s = e * rank
smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
return smoe_A, smoe_B
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
def _unwrap_experts_lora(experts_module):
return experts_module, None, None
@@ -145,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
Both gate_up_proj and down_proj need the A<->B swap because
scattermoe transposes the parameter (W = param.T).
"""
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
@@ -322,14 +300,16 @@ class TestLoRABLayoutConversion:
hidden, inter = 32, 16
scaling = 2.0
peft_A = torch.randn(E * r, hidden)
peft_B = torch.randn(inter, E * r)
peft_A = torch.randn(E * r, inter)
peft_B = torch.randn(hidden, E * r)
A_r = peft_A.reshape(E, r, hidden)
B_r = peft_B.reshape(inter, r, E)
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
A_r = peft_A.reshape(E, r, inter)
B_r = peft_B.reshape(hidden, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
assert smoe_A.shape == (E * r, inter)
assert smoe_B.shape == (hidden, E * r)
for e in range(E):
A_e = smoe_A[e * r : (e + 1) * r, :]
B_e = smoe_B[:, e * r : (e + 1) * r]
@@ -342,27 +322,26 @@ class TestLoRABLayoutConversion:
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
gate_up_proj param: [E, 2*inter, hidden].
peft: in_features=2*inter, out_features=hidden.
peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
peft: in_features=hidden, out_features=2*inter.
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
"""
E, r = 4, 2
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
scaling = 2.0
# peft assigns: in_features=2*inter, out_features=hidden
peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter]
peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E]
# peft assigns: in_features=hidden, out_features=2*inter
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
# peft delta via einsum: "o r e, e r i -> e o i"
A_r = peft_A.reshape(E, r, 2 * inter)
B_r = peft_B.reshape(hidden, r, E)
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
# delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
A_r = peft_A.reshape(E, r, hidden)
B_r = peft_B.reshape(2 * inter, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
# = param[e] shape [2*inter, hidden]
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
@@ -422,22 +401,22 @@ class TestPeftLoRAWeightExtraction:
)
# gate_up_proj [E, 2*inter, hidden]
# peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
# peft: in_features=hidden (last dim), out_features=2*inter (middle dim)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
].shape == (E * r, 2 * config.intermediate_size)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
].shape == (config.hidden_size, E * r)
# down_proj [E, hidden, inter]
# peft: in_features=hidden (dim 1), out_features=inter (dim 2)
assert trainable[
"base_model.model.moe.experts.lora_A.default.weight"
].shape == (E * r, config.hidden_size)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
].shape == (2 * config.intermediate_size, E * r)
# down_proj [E, hidden, inter]
# peft: in_features=inter (last dim), out_features=hidden (middle dim)
assert trainable[
"base_model.model.moe.experts.lora_A.default.weight"
].shape == (E * r, config.intermediate_size)
assert trainable[
"base_model.model.moe.experts.lora_B.default.weight"
].shape == (config.intermediate_size, E * r)
].shape == (config.hidden_size, E * r)
@requires_cuda
def test_peft_forward_runs(self):
@@ -488,27 +467,29 @@ class TestPeftLoRAWeightExtraction:
assert gup_lora is not None, "gate_up_proj LoRA not detected"
assert down_lora is not None, "down_proj LoRA not detected"
# Check shapes (after peft->scattermoe conversion with A<->B swap)
# gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
# Check shapes after peft->scattermoe conversion.
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
E, r = config.num_experts, 4
gup_A, gup_B, gup_s = gup_lora
assert gup_A.shape == (E * r, config.hidden_size), (
f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
f"got {gup_A.shape}"
)
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
)
# down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
down_A, down_B, down_s = down_lora
assert down_A.shape == (E * r, config.intermediate_size), (
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
f"got {down_A.shape}"
)
assert down_B.shape == (config.hidden_size, E * r), (
f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
f"got {down_B.shape}"
)

View File

@@ -361,6 +361,329 @@ class TestPluginDefaults(unittest.TestCase):
assert cfg.dataloader_num_workers == 0
class TestSelectWeightSyncTransport(unittest.TestCase):
"""Pure-logic table tests for ``select_weight_sync_transport``."""
def _caps(self, **kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_native_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_filesystem=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "lora_filesystem"
)
def test_lora_with_axolotl_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(lora_axolotl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "lora_filesystem"
)
def test_lora_falls_back_to_nccl_when_no_lora_endpoint(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True)
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=False)
== "nccl"
)
def test_full_param_prefers_nccl(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(nccl=True, http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "nccl"
)
def test_full_param_falls_back_to_http(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps(http_full=True)
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "http_full"
)
def test_full_param_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps() # all False
assert (
select_weight_sync_transport(
caps, has_lora=False, vllm_lora_sync_pref=False
)
== "none"
)
def test_lora_no_routes_returns_none(self):
from axolotl.integrations.nemo_gym.plugin import select_weight_sync_transport
caps = self._caps()
assert (
select_weight_sync_transport(caps, has_lora=True, vllm_lora_sync_pref=True)
== "none"
)
class TestProbeVllmWeightSync(unittest.TestCase):
"""``probe_vllm_weight_sync`` reads a vLLM ``/openapi.json`` and reports caps."""
def test_stock_vllm_with_lora_enabled(self):
"""Stock ``vllm serve --enable-lora`` exposes only LoRA endpoints."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/v1/models": {"get": {}},
"/v1/load_lora_adapter": {"post": {}},
"/v1/unload_lora_adapter": {"post": {}},
"/v1/completions": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.lora_filesystem is True
assert caps.lora_axolotl is False
assert caps.nccl is False
assert caps.http_full is False
def test_axolotl_serve_lora_full_capabilities(self):
"""``axolotl vllm-serve`` exposes NCCL + LoRA + HTTP full sync."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/set_lora_adapter/": {"post": {}},
"/clear_lora_adapter/": {"post": {}},
"/http_update_weights/": {"post": {}},
"/v1/load_lora_adapter": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_axolotl is True
assert caps.lora_filesystem is True
assert caps.http_full is True
def test_trl_vllm_serve_nccl_only(self):
"""``trl vllm-serve`` exposes NCCL routes but not LoRA filesystem."""
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
spec = {
"paths": {
"/init_communicator/": {"post": {}},
"/update_named_param/": {"post": {}},
"/batch_update_named_params/": {"post": {}},
"/close_communicator/": {"post": {}},
"/generate/": {"post": {}},
}
}
with patch("requests.get") as mock_get:
mock_get.return_value.raise_for_status = lambda: None
mock_get.return_value.json = lambda: spec
caps = probe_vllm_weight_sync("http://localhost:8000")
assert caps.probed is True
assert caps.nccl is True
assert caps.lora_filesystem is False
assert caps.lora_axolotl is False
assert caps.http_full is False
def test_unreachable_server_records_error(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import probe_vllm_weight_sync
with patch("requests.get") as mock_get:
mock_get.side_effect = ConnectionError("Connection refused")
caps = probe_vllm_weight_sync("http://localhost:9999")
assert caps.probed is False
assert caps.probe_error is not None
assert "ConnectionError" in caps.probe_error
assert caps.nccl is False
assert caps.lora_filesystem is False
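The probe itself boils down to checking which routes appear under `paths` in `/openapi.json`. A sketch of the mapping, using the endpoint names from the fixtures above and returning a plain dict rather than the real capabilities dataclass:

```python
def caps_from_openapi(paths: dict) -> dict:
    """Map an /openapi.json `paths` dict to capability flags (sketch of the probe)."""
    return {
        "nccl": "/init_communicator/" in paths,            # axolotl/trl vllm-serve
        "lora_axolotl": "/set_lora_adapter/" in paths,     # axolotl vllm-serve only
        "lora_filesystem": "/v1/load_lora_adapter" in paths,  # stock vllm --enable-lora
        "http_full": "/http_update_weights/" in paths,
    }

stock = caps_from_openapi({"/v1/models": {}, "/v1/load_lora_adapter": {}, "/v1/completions": {}})
print(stock["lora_filesystem"], stock["nccl"])  # True False
```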
class TestPluginWeightSyncEnforcement(unittest.TestCase):
"""End-to-end test of post_trainer_create's transport-selection branch.
The plugin used to silently no-op weight sync when ``vllm_lora_sync: false``,
leaving the trainer learning in isolation while vLLM kept serving the
unmodified base model. After the fix:
- LoRA + LoRA-loading endpoint → installs filesystem LoRA sync
- LoRA + only NCCL endpoint → uses NCCL broadcast
- Full FT + NCCL endpoint → uses NCCL broadcast (standard TRL flow)
- Full FT + HTTP endpoint → raises NotImplementedError (step 3)
- No usable transport → raises ValueError with a precise diagnosis
"""
@staticmethod
def _fake_cfg(adapter, vllm_lora_sync):
class FakeTRL:
pass
class FakeCfg:
pass
trl = FakeTRL()
trl.vllm_lora_sync = vllm_lora_sync
trl.vllm_server_host = "127.0.0.1"
trl.vllm_server_port = 8000
cfg = FakeCfg()
cfg.nemo_gym_enabled = True
cfg.nemo_gym_model_name = None
cfg.base_model = "test/model"
cfg.nemo_gym_verify_timeout = 30
cfg.nemo_gym_multi_turn = True
cfg.adapter = adapter
cfg.trl = trl
return cfg
@staticmethod
def _fake_trainer():
class FakeVLLMGen:
sync_weights = staticmethod(lambda: None)
class FakeTrainer:
vllm_generation = FakeVLLMGen()
return FakeTrainer()
@staticmethod
def _caps(**kwargs):
from axolotl.integrations.nemo_gym.plugin import VLLMWeightSyncCapabilities
c = VLLMWeightSyncCapabilities(probed=True)
for k, v in kwargs.items():
setattr(c, k, v)
return c
def test_lora_with_lora_endpoint_installs_filesystem_sync(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(lora_filesystem=True)
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with (
patch.object(plugin, "_setup_lora_sync") as setup,
patch.object(plugin, "_check_lora_endpoint") as check,
patch.object(plugin, "_wire_multi_turn") as wire,
):
plugin.post_trainer_create(cfg, trainer)
setup.assert_called_once()
check.assert_called_once()
wire.assert_called_once()
def test_lora_with_no_routes_raises_with_lora_specific_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps() # all False, but probed
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "load_lora_adapter" in msg
assert "VLLM_ALLOW_RUNTIME_LORA_UPDATING" in msg
def test_full_finetune_with_nccl_endpoint_uses_nccl(self):
from unittest.mock import patch
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(nccl=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with patch.object(plugin, "_wire_multi_turn") as wire:
plugin.post_trainer_create(cfg, trainer)
wire.assert_called_once()
def test_full_finetune_with_http_endpoint_not_implemented_yet(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps(http_full=True)
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(NotImplementedError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "HTTP weight sync" in str(ctx.exception)
def test_full_finetune_with_no_routes_raises_with_full_param_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
plugin._vllm_caps = self._caps()
cfg = self._fake_cfg(adapter=None, vllm_lora_sync=False)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
msg = str(ctx.exception)
assert "no-op trainer" in msg
assert "init_communicator" in msg
assert "http_update_weights" in msg
def test_unprobed_caps_raises_with_probe_failure_message(self):
from axolotl.integrations.nemo_gym.plugin import NemoGymPlugin
plugin = NemoGymPlugin()
# Plugin._vllm_caps left as default-None: the post_trainer_create
# branch falls back to a fresh VLLMWeightSyncCapabilities() with
# probed=False, so the error path should mention probing.
cfg = self._fake_cfg(adapter="lora", vllm_lora_sync=True)
trainer = self._fake_trainer()
with self.assertRaises(ValueError) as ctx:
plugin.post_trainer_create(cfg, trainer)
assert "could not probe" in str(ctx.exception)
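The transport-selection behavior these tests pin down can be sketched in plain Python. This is a hedged reconstruction from the assertions alone — `select_transport` is a hypothetical helper, `caps` is modeled as a dict of the same flags the `_caps` fixture sets (`probed`, `lora_filesystem`, `nccl`, `http_full`), and the real plugin's internals may differ:

```python
def select_transport(adapter, caps):
    """Pick a weight-sync transport, mirroring the branch outcomes the
    tests above assert (sketch only, not the real plugin code)."""
    if not caps.get("probed", False):
        raise ValueError("could not probe vLLM server capabilities")
    if adapter == "lora":
        if caps.get("lora_filesystem"):
            return "filesystem_lora_sync"
        raise ValueError(
            "no-op trainer: vLLM server exposes no load_lora_adapter "
            "route; set VLLM_ALLOW_RUNTIME_LORA_UPDATING"
        )
    # Full fine-tune path.
    if caps.get("nccl"):
        return "nccl"
    if caps.get("http_full"):
        raise NotImplementedError("HTTP weight sync is not implemented yet")
    raise ValueError(
        "no-op trainer: server has neither init_communicator nor "
        "http_update_weights routes"
    )
```

Each test above exercises exactly one of these branches: LoRA + filesystem route succeeds, NCCL succeeds for full FT, HTTP-only full FT is explicitly unimplemented, and the two "no routes" cases raise with adapter-specific diagnostics.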
class TestNemoGymE2E(unittest.TestCase):
"""End-to-end test: data producer → agent (mocked) → parse → tensors → rewards.
@@ -452,19 +775,15 @@ class TestNemoGymE2E(unittest.TestCase):
trainer = self._make_mock_trainer()
producer._trainer = trainer
# Mock the prompt iterator (returns a batch of 1 input)
producer._prompt_iter = iter(
[
[
{
"prompt": [{"role": "user", "content": "Play Wordle!"}],
}
]
]
)
producer._prompt_dl = [
[{"prompt": [{"role": "user", "content": "Play Wordle!"}]}]
# Mock the prompt iterator. RepeatSampler(mini_repeat_count=num_generations)
# pre-expands prompts, so the iterator yields num_generations=2 consecutive
# copies of each unique prompt — one entry per rollout.
_prompt_batch = [
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
{"prompt": [{"role": "user", "content": "Play Wordle!"}]},
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
# Call produce
result = producer.produce(model=MagicMock(), global_step=1)
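The pre-expansion described in the comment above is simple to sketch: each unique prompt appears `num_generations` times consecutively, one entry per rollout. This is a minimal stand-in for TRL's `RepeatSampler(mini_repeat_count=...)` behavior, not its actual implementation:

```python
def expand_prompts(prompts, num_generations):
    # Repeat each unique prompt num_generations times, consecutively,
    # mimicking RepeatSampler(mini_repeat_count=num_generations).
    return [p for p in prompts for _ in range(num_generations)]
```

With `num_generations=2`, a single Wordle prompt expands to the two-entry `_prompt_batch` the test constructs by hand.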
@@ -530,10 +849,13 @@ class TestNemoGymE2E(unittest.TestCase):
producer._request_timeout = 30
producer._num_generations = 2
producer._trainer = self._make_mock_trainer()
producer._prompt_iter = iter(
[[{"prompt": [{"role": "user", "content": "Play!"}]}]]
)
producer._prompt_dl = [[{"prompt": [{"role": "user", "content": "Play!"}]}]]
# RepeatSampler pre-expands by num_generations=2.
_prompt_batch = [
{"prompt": [{"role": "user", "content": "Play!"}]},
{"prompt": [{"role": "user", "content": "Play!"}]},
]
producer._prompt_iter = iter([_prompt_batch])
producer._prompt_dl = [_prompt_batch]
result = producer.produce(model=MagicMock(), global_step=1)

View File

@@ -21,6 +21,51 @@ from unittest.mock import patch
import pytest
import torch
class TestPeftScatterMoELoRALayout:
"""CPU-only guards for PEFT target_parameters layout conversion."""
def test_peft_layout_keeps_a_and_reorders_b(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
peft_lora_to_scattermoe,
)
E, r, K, N = 3, 2, 5, 7
scaling = 2.0
peft_A = torch.randn(E * r, K)
peft_B = torch.randn(N, E * r)
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
assert smoe_A is peft_A
assert smoe_A.shape == (E * r, K)
assert smoe_B.shape == (N, E * r)
A_r = peft_A.reshape(E, r, K)
B_r = peft_B.reshape(N, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
for e in range(E):
A_e = smoe_A[e * r : (e + 1) * r, :]
B_e = smoe_B[:, e * r : (e + 1) * r]
torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e])
def test_swapped_layout_fails_before_kernel_dispatch(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
validate_scattermoe_lora_shapes,
)
E, r, K, N = 3, 2, 5, 7
expert_weights = torch.empty(E, K, N)
with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
validate_scattermoe_lora_shapes(
expert_weights=expert_weights,
lora_A=torch.empty(E * r, N),
lora_B=torch.empty(K, E * r),
)
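Reading the reference reshapes in the first test (`peft_B.reshape(N, r, E)` on the PEFT side versus the per-expert column slices `e*r:(e+1)*r` on the ScatterMoE side), the B reordering amounts to permuting columns from r-major grouping (index `j*E + e`) to expert-major grouping (index `e*r + j`). A pure-Python sketch of that index permutation for one row — `reorder_b_columns` is a hypothetical illustration, not the library's `peft_lora_to_scattermoe`:

```python
def reorder_b_columns(row, E, r):
    # PEFT lays out B columns r-major (column index j*E + e, per the
    # test's reshape(N, r, E)); ScatterMoE wants expert-major blocks
    # (column index e*r + j) so each expert's r columns are contiguous.
    return [row[j * E + e] for e in range(E) for j in range(r)]
```

With `E=3, r=2` the permutation maps columns `[0..5]` to `[0, 3, 1, 4, 2, 5]`, which is exactly what makes the per-expert slice `smoe_B[:, e*r:(e+1)*r]` in the assertion line up with `delta_peft[e]`.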
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================

View File

@@ -0,0 +1,416 @@
"""
Correctness tests for the fused RMSNorm+RoPE Triton kernel.
Tests forward and backward against the reference Gemma4 implementation
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
(head_dim=256) and global attention (head_dim=512) layer configurations.
"""
import pytest
import torch
torch.manual_seed(42)
# Skip entire module if no CUDA
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def _reference_norm_rope(x, weight, cos, sin, eps):
"""Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
D = x.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
norm.weight.data.copy_(weight)
normed = norm(x)
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
def _reference_norm_noscale(x, eps):
"""Reference: Gemma4RMSNorm with_scale=False."""
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
D = x.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps, with_scale=False).to(x.device, x.dtype)
return norm(x)
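The math both reference helpers compute reduces to a per-vector root-mean-square normalization, optionally followed by a learned scale. A plain-Python sketch of the same computation (assuming the Gemma-family `(1 + weight)` scale parameterization; `with_scale=False` corresponds to `weight=None`):

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), then optionally scaled elementwise by
    # (1 + weight) — the Gemma-style RMSNorm parameterization.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v / rms for v in x]
    if weight is not None:
        out = [o * (1.0 + w) for o, w in zip(out, weight)]
    return out
```

After normalization (before scaling) the output's mean square is 1 up to `eps`, which is the property the fused kernel must reproduce on the pass-through columns.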
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
"""Reference: Gemma4RMSNorm over the full head_dim, then stock
``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
the trailing columns passed through unchanged. Mirrors how Llama-style
partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
"""
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
D = x.shape[-1]
n_rot = cos.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
norm.weight.data.copy_(weight)
normed = norm(x)
if n_rot == D:
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
x_rot = normed[..., :n_rot]
x_pass = normed[..., n_rot:]
rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2)
return torch.cat([rotated, x_pass], dim=-1)
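The slice-rotate-concat logic above can be made concrete with a minimal scalar sketch of rotate-half RoPE plus a pass-through tail. This is a hand-rolled illustration of the same split, not the transformers implementation (no norm, no batch dims):

```python
def rotate_half(v):
    # [-second_half, first_half], as in the stock rotate_half helper
    half = len(v) // 2
    return [-x for x in v[half:]] + list(v[:half])

def partial_rope(v, cos, sin):
    # Rotate only the first len(cos) columns; pass the tail through
    # unchanged — the behavior the partial-rotary kernel fix restores.
    n_rot = len(cos)
    rot, tail = v[:n_rot], v[n_rot:]
    rh = rotate_half(rot)
    return [rot[i] * cos[i] + rh[i] * sin[i] for i in range(n_rot)] + list(tail)
```

With `cos = 1, sin = 0` (the kernel's defaults for the trailing columns) the rotated slice is the identity, which is why the tail must match the plain RMSNorm output.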
@pytest.fixture(
params=[
(2, 64, 32, 256), # sliding window layer shape
(2, 64, 4, 512), # global attention layer shape
(1, 128, 16, 256), # different batch/seq
(1, 1, 1, 8), # minimal size
],
ids=["sliding_256", "global_512", "varied", "minimal"],
)
def shapes(request):
return request.param
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def dtype(request):
return request.param
class TestFusedRMSNormRoPEForward:
"""Forward pass correctness."""
def test_matches_reference(self, shapes, dtype):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = shapes
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
weight = torch.randn(D, device="cuda", dtype=dtype)
cos = torch.randn(B, S, D, device="cuda", dtype=dtype)
sin = torch.randn(B, S, D, device="cuda", dtype=dtype)
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"
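The acceptance metric used throughout these tests is ordinary cosine similarity over the flattened tensors; an equivalent pure-Python sketch of what `torch.nn.functional.cosine_similarity(..., dim=0)` computes on two flat vectors:

```python
import math

def cosine_similarity(a, b):
    # dot(a, b) / (|a| * |b|): 1.0 means the two vectors point in the
    # same direction, which is what the > 0.999 thresholds assert.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)
```

The metric is scale-invariant, so it catches directional disagreement between fused and reference outputs while tolerating the uniform magnitude noise bf16 introduces.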
def test_output_shape(self, shapes):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = shapes
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
assert y.shape == x.shape
assert y.dtype == x.dtype
class TestFusedRMSNormRoPEBackward:
"""Backward pass correctness via gradient comparison."""
@pytest.mark.parametrize(
"B,S,H,D",
[(2, 64, 32, 256), (2, 64, 4, 512)],
ids=["sliding_256", "global_512"],
)
def test_x_grad_matches_reference(self, B, S, H, D):
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# Reference backward
x_ref = torch.randn(
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight.data.copy_(weight_init)
y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
y_ref.sum().backward()
# Fused backward
x_fused = x_ref.data.clone().requires_grad_(True)
w_fused = weight_init.clone().requires_grad_(True)
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
y_fused.sum().backward()
cos_sim_x = torch.nn.functional.cosine_similarity(
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
)
assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"
@pytest.mark.parametrize(
"B,S,H,D",
[(2, 64, 32, 256), (2, 64, 4, 512)],
ids=["sliding_256", "global_512"],
)
def test_weight_grad_matches_reference(self, B, S, H, D):
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# Reference
x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
apply_rotary_pos_emb(
norm_ref(x_ref), cos, sin, unsqueeze_dim=2
).sum().backward()
# Fused
w_fused = weight_init.clone().requires_grad_(True)
fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()
cos_sim_w = torch.nn.functional.cosine_similarity(
w_fused.grad.flatten().float(),
norm_ref.weight.grad.flatten().float(),
dim=0,
)
assert cos_sim_w > 0.995, (
f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
)
def test_grad_flows(self):
"""Verify gradients are non-zero and finite."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 1, 16, 4, 64
x = torch.randn(
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
y.sum().backward()
assert x.grad is not None, "x.grad is None"
assert w.grad is not None, "w.grad is None"
assert x.grad.isfinite().all(), "x.grad has non-finite values"
assert w.grad.isfinite().all(), "w.grad has non-finite values"
assert x.grad.abs().sum() > 0, "x.grad is all zeros"
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
class TestFusedRMSNormRoPEPartialRotary:
"""Partial-rotary: cos/sin last dim is smaller than head_dim.
Compares against the original primitives (`Gemma4RMSNorm` +
`apply_rotary_pos_emb`) applied to the rotated slice with the trailing
columns passed through. Without the kernel fix this used to crash with
`RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
"""
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[
(2, 16, 4, 64, 32), # half rotary (Llama-style 0.5)
(2, 16, 4, 64, 16), # quarter rotary
(2, 32, 8, 128, 64), # half rotary, larger heads
(1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial
(1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path
],
ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
)
def test_forward_matches_reference(self, B, S, H, D, n_rot):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
assert y_fused.shape == y_ref.shape == (B, S, H, D)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, (
f"partial rotary forward cosine_sim={cos_sim:.6f} "
f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
)
# The pass-through tail must match the reference RMSNorm output to
# within bf16 tolerance (any larger deviation would mean the kernel
# is touching it with a spurious rotation, the original bug class).
torch.testing.assert_close(
y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
)
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
ids=["half_64", "quarter_256"],
)
def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
# Reference backward via the original primitives
x_ref = x_data.clone().requires_grad_(True)
w_ref = weight_init.clone()
y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
y_ref.sum().backward()
# Fused backward
x_fused = x_data.clone().requires_grad_(True)
w_fused = weight_init.clone().requires_grad_(True)
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
y_fused.sum().backward()
cos_sim_x = torch.nn.functional.cosine_similarity(
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
)
assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
ids=["half_64", "quarter_256"],
)
def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
# Reference: Gemma4RMSNorm whose .weight collects grads, then partial
# rotary applied to the rotated slice.
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
normed = norm_ref(x_data)
from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb
rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
y_ref.sum().backward()
w_fused = weight_init.clone().requires_grad_(True)
fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()
cos_sim_w = torch.nn.functional.cosine_similarity(
w_fused.grad.flatten().float(),
norm_ref.weight.grad.flatten().float(),
dim=0,
)
assert cos_sim_w > 0.995, (
f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
)
def test_full_rotary_unchanged_when_n_rot_equals_d(self):
"""Regression: passing cos/sin with shape == head_dim must still
match the full-rotary reference (the partial-rotary code path must
not perturb the existing full-rotary output)."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 2, 16, 4, 64
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"
def test_validation_errors(self):
"""Wrapper rejects misshaped inputs cleanly (instead of a cryptic
Triton crash deeper in the kernel)."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 1, 4, 2, 64
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# n_rot > head_dim
cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="cannot exceed head_dim"):
fused_rms_norm_rope(x, w, cos_big, sin_big)
# cos/sin last-dim mismatch
cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="same last dim"):
fused_rms_norm_rope(x, w, cos, sin)
# odd rotary dim
cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="must be even"):
fused_rms_norm_rope(x, w, cos_odd, sin_odd)
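The three error paths this test pins down map to simple shape preconditions on the cos/sin tensors. A sketch of equivalent wrapper-level checks — the messages match the regexes asserted above, but the real wrapper's exact code and wording may differ:

```python
def check_rope_shapes(head_dim, cos_dim, sin_dim):
    """Validate rotary-embedding dims before kernel dispatch (sketch)."""
    if cos_dim != sin_dim:
        raise ValueError("cos and sin must have the same last dim")
    if cos_dim > head_dim:
        raise ValueError("rotary dim cannot exceed head_dim")
    if cos_dim % 2 != 0:
        raise ValueError("rotary dim must be even")
```

Doing these checks in the Python wrapper turns what would be an opaque Triton launch failure into an actionable `ValueError` at the call site.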
class TestFusedRMSNormNoScale:
"""Tests for v_norm (RMSNorm without learnable scale)."""
def test_forward_matches_reference(self, shapes, dtype):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
B, S, H, D = shapes
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
y_ref = _reference_norm_noscale(x.clone(), eps)
y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"
def test_backward_flows(self):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
x = torch.randn(
1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
y = fused_rms_norm_noscale(x, eps=1e-6)
y.sum().backward()
assert x.grad is not None
assert x.grad.isfinite().all()
assert x.grad.abs().sum() > 0

View File

@@ -0,0 +1,219 @@
"""Tests for the Gemma 4 fused-attention monkey-patch.
These tests exercise the patched ``Gemma4TextAttention.forward`` against
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that the
partial-rotary RMSNorm+RoPE path through the fused Triton kernel is
exercised end-to-end (this is the bug originally documented in
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
The full-model forward also pins that the fused forward keeps accepting
whatever call shape ``Gemma4TextDecoderLayer.forward`` produces in the
installed transformers version — so any future signature drift on
upstream's side trips a clear failure here instead of a confusing
TypeError deep in a training run.
"""
import pytest
import torch
pytestmark = [
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
]
pytest.importorskip(
"transformers.models.gemma4",
reason="fused_attn patch only matters when Gemma 4 is available",
)
@pytest.fixture
def restore_gemma4_attention():
"""Snapshot ``Gemma4TextAttention.forward`` and restore after the test
so the monkey-patch does not leak across the suite."""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
saved = Gemma4TextAttention.forward
yield Gemma4TextAttention
Gemma4TextAttention.forward = saved
def _build_hybrid_config():
"""Tiny hybrid Gemma 4 config: one sliding layer + one full-attention
layer with proportional rope and partial_rotary_factor=0.25. This is
the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small
enough to fit on any GPU."""
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
global_head_dim=64,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
rope_parameters={
"sliding_attention": {
"rope_type": "default",
"rope_theta": 10000.0,
},
"full_attention": {
"rope_type": "proportional",
"rope_theta": 1000000.0,
"partial_rotary_factor": 0.25,
},
},
)
cfg._attn_implementation = "sdpa"
return cfg
def _build_model(seed=0):
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
torch.manual_seed(seed)
cfg = _build_hybrid_config()
return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval()
class TestFusedAttnSignature:
"""The fused forward must accept the same call shape as
``Gemma4TextDecoderLayer`` produces in the installed transformers
version. Any signature drift surfaces here as a TypeError."""
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
"""Run a model forward that exercises the real
``Gemma4TextDecoderLayer -> Gemma4TextAttention`` call path with
the fused patch installed."""
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model()
ids = torch.randint(0, 128, (2, 16), device="cuda")
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
patch_gemma4_fused_attn()
with torch.no_grad():
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
assert out.shape == (2, 16, 64)
assert torch.isfinite(out).all()
class TestFusedAttnPerLayerCorrectness:
"""Compare the patched attention layer to the stock implementation
on a single forward call. This isolates the fused kernel correctness
from cross-layer numerical drift."""
def _run_attention(self, model, layer_idx, hidden_states, position_ids):
"""Call ``Gemma4TextAttention.forward`` (whatever is currently
installed) for one layer and return the output."""
attn = model.layers[layer_idx].self_attn
layer_type = model.config.layer_types[layer_idx]
cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type)
out, _ = attn(
hidden_states=hidden_states,
position_embeddings=(cos, sin),
attention_mask=None,
shared_kv_states={},
)
return out
@pytest.mark.parametrize(
"layer_idx",
[0, 1],
ids=["sliding_head32", "global_head64_proportional"],
)
def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=1)
hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)
with torch.no_grad():
ref = self._run_attention(m, layer_idx, hs, pos)
patch_gemma4_fused_attn()
with torch.no_grad():
got = self._run_attention(m, layer_idx, hs, pos)
assert got.shape == ref.shape
assert torch.isfinite(got).all()
cos_sim = torch.nn.functional.cosine_similarity(
ref.flatten().float(), got.flatten().float(), dim=0
)
assert cos_sim > 0.999, (
f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}"
)
# bf16 precision: a few hundredths of absolute drift per element is
# acceptable for a Q/K/V projection pipeline. Anything larger is
# a real bug.
torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2)
class TestFusedAttnFullModel:
"""End-to-end model forward + backward through both layer types."""
def test_full_forward_matches_stock(self, restore_gemma4_attention):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=2)
ids = torch.randint(0, 128, (2, 32), device="cuda")
mask = torch.ones(2, 32, dtype=torch.long, device="cuda")
with torch.no_grad():
ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
patch_gemma4_fused_attn()
with torch.no_grad():
got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
assert got.shape == ref.shape
assert torch.isfinite(got).all()
cos_sim = torch.nn.functional.cosine_similarity(
ref.flatten().float(), got.flatten().float(), dim=0
)
# End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16
# accumulates a small amount of numerical drift; we just want to
# pin that the two paths are computing the same function.
assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}"
def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
"""Gradients must propagate through the fused RMSNorm+RoPE kernels
for both the sliding and proportional-rope layers."""
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=3).train()
patch_gemma4_fused_attn()
ids = torch.randint(0, 128, (2, 16), device="cuda")
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
out.sum().backward()
# Both layers must accumulate gradients on q_norm.weight and
# k_norm.weight — that proves the fused kernel ran the backward.
for i, layer in enumerate(m.layers[:2]):
attn = layer.self_attn
assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad"
assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad"
assert attn.q_norm.weight.grad.isfinite().all()
assert attn.k_norm.weight.grad.isfinite().all()
assert attn.q_norm.weight.grad.abs().sum() > 0
assert attn.k_norm.weight.grad.abs().sum() > 0

View File

@@ -0,0 +1,343 @@
"""Tests for the Gemma 4 hybrid-attention mask fix.
These tests pin the single critical behavior: after installing the patch,
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
the underlying mask builder regardless of what the caller's config says.
This is what keeps full-attention (head_dim=512) global layers from
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
the 2D FA2 mask that would be built from the model-level config.
The tests use a mocked ``create_causal_mask`` so they don't have to load
a real 26B Gemma 4 model or even have access to its weights. What matters
for the bug fix is which config is handed to the mask factory, not the
factory's actual output.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
pytest.importorskip(
"transformers.models.gemma4",
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
)
@pytest.fixture
def restore_gemma4_module():
"""Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
each test so patch state doesn't leak across the suite."""
from transformers.models.gemma4 import modeling_gemma4
saved = modeling_gemma4.create_causal_mask
yield modeling_gemma4
modeling_gemma4.create_causal_mask = saved
# Reset the module-level flag so the next test can re-install cleanly.
from axolotl.monkeypatch import gemma4_hybrid_mask
gemma4_hybrid_mask._PATCH_APPLIED = False
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
original = modeling_gemma4.create_causal_mask
assert patch_gemma4_hybrid_mask() is True
assert modeling_gemma4.create_causal_mask is not original
assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
"patched wrapper must expose the original reference for teardown"
)
def test_patch_is_idempotent(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
wrapper_first = modeling_gemma4.create_causal_mask
# Second call must not re-wrap the already-wrapped function (which
# would leak the original reference through a chain of wrappers).
patch_gemma4_hybrid_mask()
wrapper_second = modeling_gemma4.create_causal_mask
assert wrapper_first is wrapper_second
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
"""Core invariant: when the patched wrapper is called with a config
that says ``flash_attention_2``, the underlying mask factory receives
a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
Without this, the full-attention global layers get a 2D FA2 mask and
crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
"""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
# Swap in a mock BEFORE installing the patch so the wrapper captures
# it as the "original". The mock records every call so we can inspect
# what config got passed through.
mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
modeling_gemma4.create_causal_mask = mock_factory
patch_gemma4_hybrid_mask()
# Caller-supplied config says FA2 (that's the model-level setting).
caller_config = SimpleNamespace(
_attn_implementation="flash_attention_2",
head_dim=512,
some_other_attr="preserved",
)
result = modeling_gemma4.create_causal_mask(
caller_config,
inputs_embeds=None,
attention_mask=None,
past_key_values=None,
position_ids=None,
)
# Wrapper returned whatever the mock returned — no transformation of
# the result itself.
assert result == "mask_4d"
# The mock was called exactly once with a config whose
# ``_attn_implementation`` is sdpa, NOT the caller's fa2.
assert mock_factory.call_count == 1
(passed_config, *_), passed_kwargs = mock_factory.call_args
assert passed_config._attn_implementation == "sdpa"
# The wrapper must NOT mutate the caller's config in place — other
# mask builders (e.g. create_sliding_window_causal_mask) read from
# the same config and must still see fa2.
assert caller_config._attn_implementation == "flash_attention_2"
# Other attributes on the config must be preserved so the underlying
# factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
assert passed_config.head_dim == 512
assert passed_config.some_other_attr == "preserved"
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
"""The wrapper must forward positional + keyword args to the original
unchanged, so transformers' own call-site in Gemma4TextModel.forward
keeps working across minor transformers-version signature drift."""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
mock_factory = MagicMock(return_value="mask")
modeling_gemma4.create_causal_mask = mock_factory
patch_gemma4_hybrid_mask()
caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
modeling_gemma4.create_causal_mask(
caller_config,
"positional_arg",
inputs_embeds="embeds",
attention_mask="mask_2d",
past_key_values="cache",
position_ids="positions",
or_mask_function="or_fn",
)
args, kwargs = mock_factory.call_args
# First positional (after config override) is preserved.
assert args[1] == "positional_arg"
# All kwargs are forwarded untouched.
assert kwargs["inputs_embeds"] == "embeds"
assert kwargs["attention_mask"] == "mask_2d"
assert kwargs["past_key_values"] == "cache"
assert kwargs["position_ids"] == "positions"
assert kwargs["or_mask_function"] == "or_fn"
def test_unpatch_restores_original(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import (
patch_gemma4_hybrid_mask,
unpatch_gemma4_hybrid_mask,
)
sentinel = MagicMock(name="original")
modeling_gemma4.create_causal_mask = sentinel
patch_gemma4_hybrid_mask()
assert modeling_gemma4.create_causal_mask is not sentinel
unpatch_gemma4_hybrid_mask()
assert modeling_gemma4.create_causal_mask is sentinel
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
# Should be a no-op, no exception.
unpatch_gemma4_hybrid_mask()
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
"""Only ``create_causal_mask`` is overridden — the sliding-window
factory must remain bound to its original to preserve FA2 masks for
the sliding-attention layers. If we accidentally patch both, the
sliding layers get SDPA format and lose the FA2 speedup."""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
pytest.skip("transformers version has no create_sliding_window_causal_mask")
sliding_before = modeling_gemma4.create_sliding_window_causal_mask
patch_gemma4_hybrid_mask()
sliding_after = modeling_gemma4.create_sliding_window_causal_mask
assert sliding_after is sliding_before
# ---------------------------------------------------------------------------
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
#
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
# model with 2 layers (one sliding, one full_attention), apply the hybrid
# attention path end-to-end, and run a forward pass with a padded
# attention_mask at a long-ish seq len. The invariant we're pinning is that
# the full_attention layer does not crash with the
# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
# tokens in the FSDP2 training run.
# ---------------------------------------------------------------------------
def _build_tiny_gemma4_text_model():
"""Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
import torch
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
)
# Caller-supplied attn impl simulates the pilot config (fa2 at model
# level). The hybrid patch is what makes this survive long context.
cfg._attn_implementation = "sdpa"  # start safe; the mask-shape test below sets fa2 on its own config
torch.manual_seed(42)
model = Gemma4TextModel(cfg).eval()
return model, cfg
def _apply_hybrid_attn_inline(model, cfg):
"""Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
to a model, without needing a full PatchManager / pydantic cfg."""
import copy
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
for layer_idx, layer in enumerate(model.layers):
if cfg.layer_types[layer_idx] != "sliding_attention":
attn = getattr(layer, "self_attn", None)
if attn is not None and hasattr(attn, "config"):
sdpa_cfg = copy.copy(attn.config)
sdpa_cfg._attn_implementation = "sdpa"
attn.config = sdpa_cfg
patch_gemma4_hybrid_mask()
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
"""End-to-end invariant: with the hybrid attn patch applied, a tiny
Gemma4TextModel runs a forward at long context (1024 tokens) with
real padding in the attention mask, producing the expected output
shape. This exercises the actual code path that crashed the pilot
without needing a real 26B checkpoint or CUDA."""
import torch
model, cfg = _build_tiny_gemma4_text_model()
_apply_hybrid_attn_inline(model, cfg)
B, S = 2, 1024
input_ids = torch.randint(0, cfg.vocab_size, (B, S))
attn_mask = torch.ones(B, S, dtype=torch.long)
# Pad positions in the second row. Without padding, SDPA falls back to
# ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
# mask to exercise the actual bug site.
attn_mask[1, S // 2 :] = 0
with torch.no_grad():
out = model(input_ids=input_ids, attention_mask=attn_mask)
assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
assert torch.isfinite(out.last_hidden_state).all()
def test_patched_create_causal_mask_returns_4d_for_real_config(
restore_gemma4_module,
):
"""Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
and verify the returned mask is a 4D tensor — which is the shape the
SDPA-patched global layers need. Without the patch and with a
caller-supplied FA2 config this would return a 2D mask and the layer
would crash at long context."""
import torch
from transformers.cache_utils import DynamicCache
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
modeling_gemma4 = restore_gemma4_module
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
)
# Simulate the pilot: caller says flash_attention_2, but global layers
# were switched to SDPA per-layer. Without the patch, create_causal_mask
# would return an FA2 2D mask here and the SDPA layer would crash.
cfg._attn_implementation = "flash_attention_2"
B, S = 2, 1024
inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
attention_mask = torch.ones((B, S), dtype=torch.long)
attention_mask[1, S // 2 :] = 0 # force the 4D materialized path
position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
past_key_values = DynamicCache(config=cfg)
mask = modeling_gemma4.create_causal_mask(
config=cfg,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
assert mask is not None
assert isinstance(mask, torch.Tensor)
assert mask.dim() == 4, (
f"expected a 4D SDPA-format mask, got {mask.dim()}D "
f"shape={tuple(mask.shape)}. The full_attention global layers need "
"this shape or they crash at long context."
)
assert mask.shape[0] == B
assert mask.shape[-1] == S
assert mask.shape[-2] == S
# Caller's config must be untouched — other code paths still read it.
assert cfg._attn_implementation == "flash_attention_2"


@@ -916,6 +916,235 @@ class TestChatTemplateConfigurations:
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
@enable_hf_offline
def test_content_parts_training(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing with content as list of parts with per-part training")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Dataset where assistant content is a list of parts with per-part training
conversation = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are an AI assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is 2+2?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think...", "train": False},
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find the assistant turn (last turn)
assistant_turn_idx = len(turns) - 1
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=assistant_turn_idx
)
assert start_idx != -1 and end_idx != -1, (
"Could not find assistant turn boundaries"
)
decoded = tokenizer.decode(input_ids[start_idx:end_idx])
LOG.debug(f"Assistant turn decoded: {decoded}")
# Tokenize each part separately to find their boundaries
part1_text = "Let me think..."
part2_text = "The answer is 4."
# Verify the concatenated content is in the decoded output
assert part1_text in decoded, (
f"Part 1 '{part1_text}' not found in decoded: {decoded}"
)
assert part2_text in decoded, (
f"Part 2 '{part2_text}' not found in decoded: {decoded}"
)
# Verify that part1 tokens (train=False) are masked
# and part2 tokens (train=True) are labeled
turn_labels = labels[start_idx:end_idx]
# Find where part2 starts in the token sequence
part1_tokens = tokenizer(part1_text, add_special_tokens=False)["input_ids"]
part2_tokens = tokenizer(part2_text, add_special_tokens=False)["input_ids"]
# The first part should be masked (all IGNORE_TOKEN_ID)
# Due to token boundary alignment, check that at least the interior tokens
# of part1 are masked
assert any(label == IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected some masked labels for train=False part, but got {turn_labels}"
)
# The second part should be trained (not IGNORE_TOKEN_ID)
assert any(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected some trained labels for train=True part, but got {turn_labels}"
)
# More precise check: first N tokens should be masked, last M tokens should be trained
# where N ~ len(part1_tokens) and M ~ len(part2_tokens)
# Allow for token boundary effects at the boundary
num_masked = sum(1 for label in turn_labels if label == IGNORE_TOKEN_ID)
num_trained = sum(1 for label in turn_labels if label != IGNORE_TOKEN_ID)
LOG.debug(f"Turn labels: {turn_labels}")
LOG.debug(f"Masked tokens: {num_masked}, Trained tokens: {num_trained}")
LOG.debug(
f"Part1 tokens: {len(part1_tokens)}, Part2 tokens: {len(part2_tokens)}"
)
# The number of masked tokens should be roughly the size of part1
# and the number of trained tokens should be roughly the size of part2
assert num_masked > 0, "Expected masked tokens for the train=False part"
assert num_trained > 0, "Expected trained tokens for the train=True part"
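The per-part supervision this test exercises can be sketched as a label builder that masks parts flagged `train=False` (or `weight=0`). This is a simplified illustration of the idea, not axolotl's actual tokenization path, and `labels_for_parts` is an invented helper name:

```python
IGNORE_TOKEN_ID = -100  # standard HF label-masking sentinel


def labels_for_parts(parts, tokenizer):
    """Hypothetical sketch: tokenize each content part and mask the labels
    of parts marked train=False (or weight=0); train the rest."""
    input_ids, labels = [], []
    for part in parts:
        ids = tokenizer(part["text"], add_special_tokens=False)["input_ids"]
        if "train" in part:
            trainable = part["train"]
        else:
            trainable = part.get("weight", 1) != 0
        input_ids.extend(ids)
        labels.extend(ids if trainable else [IGNORE_TOKEN_ID] * len(ids))
    return input_ids, labels
```

The real strategy additionally has to locate each part inside the fully templated turn, which is why the tests above allow for token-boundary effects instead of asserting exact offsets.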
@enable_hf_offline
def test_content_parts_with_weight(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing with content parts using weight field")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Dataset using weight instead of train
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Thinking step by step: ", "weight": 0},
{"type": "text", "text": "Hello! How can I help?", "weight": 1},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
labels = res["labels"]
# There should be both masked and trained labels
has_masked = any(label == IGNORE_TOKEN_ID for label in labels)
has_trained = any(label != IGNORE_TOKEN_ID for label in labels)
assert has_masked, "Expected masked tokens (weight=0 part + user turn)"
assert has_trained, "Expected trained tokens (weight=1 part)"
@enable_hf_offline
def test_content_parts_string_passthrough(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing that string content still works alongside list content")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# All list content in the conversation
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is 2+2?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
# Should tokenize without errors
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
def test_get_chat_template_variables(
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
):
@@ -1428,3 +1657,250 @@ class TestChatTemplateToolCalling:
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Assistant turn {i} should be unmasked"
)
class TestChatTemplateReasoningContent:
"""
Test class for reasoning_content with content parts.
"""
@enable_hf_offline
def test_reasoning_content_with_content_parts(self, qwen3_tokenizer):
"""Test that reasoning_content as string + content as list parts works correctly.
Content training_detail offsets should align with content-only boundaries."""
LOG.info("Testing reasoning_content with content parts on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# reasoning_content is a plain string, content is list with per-part training
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": "Step 1: 2+2=4",
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find the assistant turn
assistant_idx = 1
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=assistant_idx, content_only=True
)
assert start_idx != -1 and end_idx != -1, (
"Could not find assistant content boundaries"
)
# The content-only span should contain "The answer is 4." but NOT "Step 1: 2+2=4"
decoded_span = tokenizer.decode(input_ids[start_idx:end_idx])
assert "The answer is 4." in decoded_span, (
f"Content not found in span: {decoded_span}"
)
assert "Step 1" not in decoded_span, (
f"Reasoning should not be in content-only span: {decoded_span}"
)
# Verify that content tokens are trained
content_labels = labels[start_idx:end_idx]
assert any(label != IGNORE_TOKEN_ID for label in content_labels), (
f"Expected trained labels in content span, got {content_labels}"
)
@enable_hf_offline
def test_reasoning_content_per_part_masking(self, qwen3_tokenizer):
"""Test masking incorrect reasoning while training on self-correction.
This is the core use case: mask out wrong thoughts, train on corrections."""
LOG.info("Testing reasoning_content per-part masking on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Reasoning has wrong step (masked) then self-correction (trained)
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": False},
{"type": "text", "text": " Wait no, 2+2=4.", "train": True},
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find reasoning boundaries
reasoning_start, reasoning_end = strategy.find_turn(
turns=turns, turn_idx=1, reasoning_only=True
)
assert reasoning_start != -1 and reasoning_end != -1, (
"Could not find reasoning boundaries"
)
decoded_reasoning = tokenizer.decode(input_ids[reasoning_start:reasoning_end])
LOG.debug(f"Reasoning span: {decoded_reasoning!r}")
assert "2+2=5" in decoded_reasoning, (
f"Wrong step not in reasoning span: {decoded_reasoning}"
)
assert "2+2=4" in decoded_reasoning, (
f"Correction not in reasoning span: {decoded_reasoning}"
)
# Verify reasoning labels have both masked and trained tokens
reasoning_labels = labels[reasoning_start:reasoning_end]
reasoning_ids = input_ids[reasoning_start:reasoning_end]
# Decode only the trained tokens — should be exactly the self-correction
trained_ids = [
tid
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
if lab != IGNORE_TOKEN_ID
]
trained_text = tokenizer.decode(trained_ids)
assert trained_text.strip() == "Wait no, 2+2=4.", (
f"Expected trained reasoning to be 'Wait no, 2+2=4.', got: {trained_text!r}"
)
# Decode only the masked tokens — should be exactly the incorrect step
masked_ids = [
tid
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
if lab == IGNORE_TOKEN_ID
]
masked_text = tokenizer.decode(masked_ids)
assert masked_text.strip() == "Hmm maybe 2+2=5.", (
f"Expected masked reasoning to be 'Hmm maybe 2+2=5.', got: {masked_text!r}"
)
# Find content boundaries
content_start, content_end = strategy.find_turn(
turns=turns, turn_idx=1, content_only=True
)
assert content_start != -1 and content_end != -1, (
"Could not find content boundaries"
)
# Content should be fully trained — decode trained tokens to verify
content_labels = labels[content_start:content_end]
content_ids = input_ids[content_start:content_end]
content_trained_ids = [
tid
for tid, lab in zip(content_ids, content_labels, strict=True)
if lab != IGNORE_TOKEN_ID
]
content_trained_text = tokenizer.decode(content_trained_ids)
assert "The answer is 4." in content_trained_text, (
f"Expected 'The answer is 4.' in trained content tokens, "
f"got: {content_trained_text!r}"
)
assert all(label != IGNORE_TOKEN_ID for label in content_labels), (
f"Expected all content labels trained, got {content_labels}"
)
@enable_hf_offline
def test_reasoning_content_as_list_no_training_flags(self, qwen3_tokenizer):
"""Test that reasoning_content as list without training flags still works."""
LOG.info("Testing reasoning_content as list without training flags on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Both as lists, no per-part training flags
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Step 1: addition."},
{"type": "text", "text": " Step 2: 2+2=4."},
],
"content": [
{"type": "text", "text": "The answer is 4."},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
# Should tokenize without errors
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
# Verify the full output contains both reasoning and content
full_text = tokenizer.decode(res["input_ids"])
assert "Step 1: addition." in full_text
assert "Step 2: 2+2=4." in full_text
assert "The answer is 4." in full_text


@@ -65,47 +65,57 @@ def test_singleton_instance(telemetry_manager_class):
assert telemetry_manager_class.get_instance() is first
-def test_telemetry_enabled_by_default(telemetry_manager_class):
-"""Test that telemetry is enabled by default (opt-out)"""
-with (
-patch.dict(os.environ, {"RANK": "0"}, clear=True),
-patch("time.sleep"),
-patch("logging.Logger.info"),
-):
-manager = telemetry_manager_class()
-assert manager.enabled
-def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
-"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
-with (
-patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
-patch("time.sleep"),
-):
-manager = telemetry_manager_class()
-assert manager.enabled
-def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
-"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
-with (
-patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
-patch("time.sleep"),
-):
-manager = telemetry_manager_class()
-assert not manager.enabled
-def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
-"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
-with (
-patch.dict(
-os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
-),
-patch("time.sleep"),
-):
-manager = telemetry_manager_class()
-assert not manager.enabled
+class TestTelemetryOptOut:
+"""
+Telemetry is opt-out: enabled by default, disabled by AXOLOTL_DO_NOT_TRACK
+or DO_NOT_TRACK. Each env var is checked independently — setting either one
+to a truthy value ("1" or "true") disables telemetry.
+The parametrized table below is the source of truth for expected behavior.
+"""
+# fmt: off
+# AXOLOTL_DO_NOT_TRACK DO_NOT_TRACK expected
+@pytest.mark.parametrize("axolotl_dnt, dnt, expected", [
+# --- Neither var set: telemetry ON ---
+(None, None, True),
+# --- Only AXOLOTL_DO_NOT_TRACK set ---
+("0", None, True), # explicit opt-in
+("false", None, True), # explicit opt-in
+("1", None, False), # opt-out
+("true", None, False), # opt-out
+(" 1 ", None, False), # whitespace-padded opt-out
+# --- Only DO_NOT_TRACK set (was broken before fix) ---
+(None, "0", True), # explicit opt-in
+(None, "false", True), # explicit opt-in
+(None, "1", False), # opt-out
+(None, "true", False), # opt-out
+# --- Both set: either truthy → disabled ---
+("0", "1", False), # DO_NOT_TRACK wins
+("1", "0", False), # AXOLOTL_DO_NOT_TRACK wins
+("1", "1", False), # both opt-out
+("0", "0", True), # both opt-in
+])
+# fmt: on
+def test_do_not_track_env_vars(
+self, telemetry_manager_class, axolotl_dnt, dnt, expected
+):
+env = {"RANK": "0"}
+if axolotl_dnt is not None:
+env["AXOLOTL_DO_NOT_TRACK"] = axolotl_dnt
+if dnt is not None:
+env["DO_NOT_TRACK"] = dnt
+with (
+patch.dict(os.environ, env, clear=True),
+patch("time.sleep"),
+patch("logging.Logger.info"),
+):
+manager = telemetry_manager_class()
+assert manager.enabled is expected
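The env-var semantics the parametrized table pins down can be sketched as a small predicate. This is a hypothetical helper written from the table alone, not the manager's actual implementation; `telemetry_opt_out` is an invented name:

```python
import os


def telemetry_opt_out(environ=None) -> bool:
    """Hypothetical check mirroring the table: either variable, trimmed and
    lowercased, equal to "1" or "true" disables telemetry; anything else
    (including "0"/"false" or unset) leaves it enabled."""
    env = os.environ if environ is None else environ

    def truthy(name):
        return env.get(name, "").strip().lower() in {"1", "true"}

    return truthy("AXOLOTL_DO_NOT_TRACK") or truthy("DO_NOT_TRACK")
```

Checking each variable independently is what makes the "(was broken before fix)" rows pass: DO_NOT_TRACK alone must be enough to opt out.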
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):


@@ -0,0 +1,63 @@
"""Tests for SkipEvalOnResumeCallback."""
from unittest.mock import MagicMock
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
class TestSkipEvalOnResumeCallback:
"""Tests for skipping redundant evaluation on checkpoint resume."""
@staticmethod
def _make_state(global_step: int) -> TrainerState:
state = MagicMock(spec=TrainerState)
state.global_step = global_step
return state
def test_suppresses_eval_at_resume_step(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(20)
control = TrainerControl(should_evaluate=False)
# Simulate on_train_begin at checkpoint-20
cb.on_train_begin(args, state, control)
# Trainer sets should_evaluate = True for step 20
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is False
def test_allows_eval_after_resume_step(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(20)
control = TrainerControl(should_evaluate=False)
cb.on_train_begin(args, state, control)
# Advance past the resume point
state.global_step = 30
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is True
def test_noop_on_fresh_run(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(0)
control = TrainerControl(should_evaluate=False)
# Fresh run: global_step starts at 0
cb.on_train_begin(args, state, control)
# Even if eval triggers at step 0 (unlikely but defensive)
state.global_step = 10
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is True

tests/utils/data/test_rl.py Normal file

@@ -0,0 +1,292 @@
"""
Unit tests for RL data utility functions (excess_length_strategy support).
"""
import unittest
from axolotl.utils.data.rl import (
_drop_long_sequences,
_raise_on_long_sequences,
_truncate_long_sequences_rl,
)
from axolotl.utils.schemas.enums import RLType
class _FakeTokenizer:
"""Simple whitespace tokenizer for testing length calculations."""
def __call__(self, text, add_special_tokens=True): # noqa: ARG002
tokens = text.split()
return {"input_ids": list(range(len(tokens)))}
def decode(self, token_ids, skip_special_tokens=True): # noqa: ARG002
# Each token id maps to a placeholder word; length is what matters.
return " ".join(f"w{i}" for i in range(len(token_ids)))
def _make_dpo_sample(prompt_len: int, chosen_len: int, rejected_len: int):
"""Create a DPO sample with specified word counts."""
return {
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
"chosen": " ".join(f"c{i}" for i in range(chosen_len)),
"rejected": " ".join(f"r{i}" for i in range(rejected_len)),
}
def _make_kto_sample(prompt_len: int, completion_len: int):
"""Create a KTO sample with specified word counts."""
return {
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
"completion": " ".join(f"c{i}" for i in range(completion_len)),
}
class TestDropLongSequences(unittest.TestCase):
"""Tests for the existing _drop_long_sequences filter function."""
def setUp(self):
self.tokenizer = _FakeTokenizer()
def test_dpo_keeps_short_samples(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_dpo_drops_long_chosen(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertFalse(result)
def test_dpo_drops_long_rejected(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=2, rejected_len=10)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertFalse(result)
def test_kto_keeps_short_samples(self):
sample = _make_kto_sample(prompt_len=3, completion_len=2)
result = _drop_long_sequences(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_kto_drops_long_completion(self):
sample = _make_kto_sample(prompt_len=5, completion_len=10)
result = _drop_long_sequences(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertFalse(result)
def test_grpo_always_keeps(self):
sample = {"prompt": "a " * 100}
result = _drop_long_sequences(
sample, RLType.GRPO, self.tokenizer, sequence_len=5
)
self.assertTrue(result)
def test_dpo_missing_keys_raises(self):
with self.assertRaises(ValueError):
_drop_long_sequences({"prompt": "hi"}, RLType.DPO, self.tokenizer, 10)
def test_kto_missing_keys_raises(self):
with self.assertRaises(ValueError):
_drop_long_sequences({"prompt": "hi"}, RLType.KTO, self.tokenizer, 10)
def test_ipo_uses_dpo_logic(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.IPO, self.tokenizer, sequence_len=10
)
self.assertFalse(result)
def test_orpo_uses_dpo_logic(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.ORPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_boundary_length_kept(self):
"""Samples exactly at sequence_len should be kept."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
class TestRaiseOnLongSequences(unittest.TestCase):
"""Tests for _raise_on_long_sequences (excess_length_strategy='raise')."""
def setUp(self):
self.tokenizer = _FakeTokenizer()
def test_short_sample_passes(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _raise_on_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_long_sample_raises_valueerror(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
with self.assertRaises(ValueError, msg="excess_length_strategy"):
_raise_on_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
def test_kto_long_raises(self):
sample = _make_kto_sample(prompt_len=5, completion_len=10)
with self.assertRaises(ValueError):
_raise_on_long_sequences(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
def test_grpo_never_raises(self):
sample = {"prompt": "a " * 100}
result = _raise_on_long_sequences(
sample, RLType.GRPO, self.tokenizer, sequence_len=5
)
self.assertTrue(result)
class TestTruncateLongSequencesRL(unittest.TestCase):
"""Tests for _truncate_long_sequences_rl (excess_length_strategy='truncate')."""
def setUp(self):
self.tokenizer = _FakeTokenizer()
def test_dpo_short_sample_unchanged(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["chosen"], sample["chosen"])
self.assertEqual(result["rejected"], sample["rejected"])
def test_dpo_truncates_chosen(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
# max_response_len = 10 - 5 = 5, chosen had 10 words -> truncated to 5
chosen_tokens = self.tokenizer(result["chosen"], add_special_tokens=False)[
"input_ids"
]
self.assertEqual(len(chosen_tokens), 5)
def test_dpo_truncates_rejected(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=3, rejected_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
rejected_tokens = self.tokenizer(result["rejected"], add_special_tokens=False)[
"input_ids"
]
self.assertEqual(len(rejected_tokens), 5)
def test_dpo_truncates_both(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
chosen_len = len(
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
)
rejected_len = len(
self.tokenizer(result["rejected"], add_special_tokens=False)["input_ids"]
)
self.assertEqual(chosen_len, 5)
self.assertEqual(rejected_len, 5)
def test_dpo_prompt_unchanged(self):
"""Prompt text should never be modified."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["prompt"], sample["prompt"])
def test_dpo_prompt_exceeds_limit_returns_unchanged(self):
"""When prompt alone exceeds sequence_len, sample is returned as-is."""
sample = _make_dpo_sample(prompt_len=15, chosen_len=3, rejected_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result, sample)
def test_kto_truncates_completion(self):
sample = _make_kto_sample(prompt_len=5, completion_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
completion_len = len(
self.tokenizer(result["completion"], add_special_tokens=False)["input_ids"]
)
self.assertEqual(completion_len, 5)
def test_kto_short_sample_unchanged(self):
sample = _make_kto_sample(prompt_len=3, completion_len=2)
result = _truncate_long_sequences_rl(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["completion"], sample["completion"])
def test_kto_prompt_exceeds_limit_returns_unchanged(self):
sample = _make_kto_sample(prompt_len=15, completion_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertEqual(result, sample)
def test_grpo_unchanged(self):
sample = {"prompt": "a " * 100}
result = _truncate_long_sequences_rl(
sample, RLType.GRPO, self.tokenizer, sequence_len=5
)
self.assertEqual(result, sample)
def test_ipo_uses_dpo_logic(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.IPO, self.tokenizer, sequence_len=10
)
chosen_len = len(
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
)
self.assertEqual(chosen_len, 5)
def test_does_not_mutate_original(self):
"""Verify immutability — original sample dict is not modified."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
original_chosen = sample["chosen"]
original_rejected = sample["rejected"]
_truncate_long_sequences_rl(sample, RLType.DPO, self.tokenizer, sequence_len=10)
self.assertEqual(sample["chosen"], original_chosen)
self.assertEqual(sample["rejected"], original_rejected)
def test_dpo_missing_keys_raises(self):
with self.assertRaises(ValueError):
_truncate_long_sequences_rl(
{"prompt": "hi"}, RLType.DPO, self.tokenizer, 10
)
def test_kto_missing_keys_raises(self):
with self.assertRaises(ValueError):
_truncate_long_sequences_rl(
{"prompt": "hi"}, RLType.KTO, self.tokenizer, 10
)
def test_boundary_no_truncation_needed(self):
"""Samples exactly at sequence_len should not be modified."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["chosen"], sample["chosen"])
self.assertEqual(result["rejected"], sample["rejected"])
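The budget rule these tests pin down can be sketched as a plain function. This is a hypothetical stand-in, not the real `_truncate_long_sequences_rl` (which operates on the DPO/KTO text fields via the tokenizer); the name and signature here are assumptions:

```python
# Hypothetical sketch of the truncation contract exercised above: the
# prompt is never modified, responses get whatever token budget remains,
# and a prompt that alone exceeds the limit passes the sample through as-is.
def truncate_response(prompt_ids, response_ids, sequence_len):
    if len(prompt_ids) >= sequence_len:
        # Prompt alone exceeds the budget: leave the response untouched,
        # matching test_dpo_prompt_exceeds_limit_returns_unchanged.
        return response_ids
    # Remaining budget after the prompt, e.g. 10 - 5 = 5 in the tests.
    max_response_len = sequence_len - len(prompt_ids)
    return response_ids[:max_response_len]
```

For DPO/IPO, `chosen` and `rejected` are each truncated independently against the same budget; GRPO samples bypass truncation entirely.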

View File

@@ -2,6 +2,7 @@ import json
import math
from unittest.mock import Mock, patch
import pytest
import safetensors.torch
import torch
@@ -490,7 +491,8 @@ class TestEfficientMerge:
out_features = 4
alpha = 4
-base = torch.randn(num_experts, in_features, out_features)
+# PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
+base = torch.randn(num_experts, out_features, in_features)
lora_a = torch.randn(r * num_experts, in_features)
lora_b = torch.randn(out_features, r * num_experts)
@@ -506,7 +508,7 @@ class TestEfficientMerge:
scale = alpha / r
wa = lora_a.reshape(num_experts, r, in_features)
wb = lora_b.reshape(out_features, r, num_experts)
-manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
+manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
for e in range(num_experts):
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
f"Expert {e} mismatch"
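The corrected subscript order can be sanity-checked with a small NumPy stand-in for the `torch.einsum` call in the test (dimension sizes here are arbitrary):

```python
import numpy as np

# Shapes mirror the test: lora_a stacks experts along dim 0 (reshaped to
# (e, r, i)), lora_b stacks experts along the last dim as (o, r, e).
num_experts, r, in_features, out_features = 2, 3, 5, 4
wa = np.random.randn(num_experts, r, in_features)   # (e, r, i)
wb = np.random.randn(out_features, r, num_experts)  # (o, r, e)

# The corrected output order "e o i" matches the (experts, out, in)
# layout of the base weight.
delta = np.einsum("ore,eri->eoi", wb, wa)

# Per expert this is just B_e @ A_e, with B_e = wb[:, :, e] of shape
# (out, r) and A_e = wa[e] of shape (r, in).
for e in range(num_experts):
    assert np.allclose(delta[e], wb[:, :, e] @ wa[e])
```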
@@ -773,8 +775,8 @@ class TestEfficientMerge:
"v_proj should be unchanged (no LoRA weights for it)"
)
-def test_dora_missing_magnitude_falls_back(self):
-"""DoRA without magnitude vector falls back to standard LoRA merge."""
+def test_dora_missing_magnitude_raises(self):
+"""DoRA with missing magnitude vector raises an explicit error."""
hidden = 16
r = 4
alpha = 8
@@ -791,11 +793,13 @@ class TestEfficientMerge:
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
-merged, was_merged = _merge_tensor_with_lora(
-base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
-)
-assert was_merged
-# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
-# which produces a result different from plain W + scale * B @ A.
-# Just verify it was merged (not unchanged).
-assert not torch.equal(merged, base)
+with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
+_merge_tensor_with_lora(
+base,
+"layer.proj.weight",
+lora_state,
+scale,
+config,
+"cpu",
+use_dora=True,
+)
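The rationale for the stricter behavior: a DoRA merge needs the learned magnitude vector, and there is no principled default to fall back on. A minimal NumPy sketch of the standard DoRA merge formula, assuming the usual decomposition (function name and signature are hypothetical, not axolotl's API):

```python
import numpy as np

def dora_merge(base, lora_a, lora_b, scale, magnitude=None):
    """Hypothetical DoRA merge: W' = m * (W0 + scale*B@A) / rownorm(W0 + scale*B@A)."""
    if magnitude is None:
        # Defaulting m silently changes the merged weights, which is why
        # the updated test expects an explicit error instead of a fallback.
        raise ValueError("DoRA merge requires a magnitude vector")
    directional = base + scale * (lora_b @ lora_a)  # (out, in)
    # One norm per output row; magnitude has out_features entries.
    norms = np.linalg.norm(directional, axis=1, keepdims=True)
    return magnitude[:, None] * directional / norms
```

With zero LoRA factors and unit magnitudes this reduces to row-normalizing the base weight, which makes the magnitude's role easy to see.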

View File

@@ -5,6 +5,8 @@ Covers:
- save_strategy: 'best' requires metric_for_best_model
- streaming=True with val_set_size > 0 is rejected
- lora_target_modules with invalid regex patterns is rejected
- GRPO: generation batch size must be divisible by num_generations,
num_generations >= 2, and effective_gbs >= num_generations * world_size
"""
import pytest
@@ -117,3 +119,136 @@ class TestLoraTargetModulesRegexValidator:
)
with pytest.raises(ValueError, match="invalid regex pattern"):
validate_config(cfg)
class TestGRPOBatchSizeValidator:
"""GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.
These call the @model_validator(mode="before") classmethod directly on a
plain dict — same input shape it receives during full Pydantic validation,
just without dragging in unrelated fields (datasets / model loading / etc.)
that aren't relevant to what's under test. The validator is registered on
``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits) so this is the
same code path ``axolotl train`` exercises.
"""
@staticmethod
def _check(data):
from axolotl.utils.schemas.validation import RLValidationMixin
return RLValidationMixin.check_grpo_batch_size_divisibility(data)
def test_divisible_passes(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"trl": {"num_generations": 4},
}
# Should return data unchanged (no exception)
out = self._check(data)
assert out["trl"]["num_generations"] == 4
def test_non_divisible_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match="num_generations"):
self._check(data)
def test_non_divisible_error_includes_fix_hint(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
self._check(data)
def test_num_generations_one_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"trl": {"num_generations": 1},
}
with pytest.raises(ValueError, match=r"num_generations >= 2"):
self._check(data)
def test_explicit_generation_batch_size_divisible_passes(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"trl": {"num_generations": 4, "generation_batch_size": 8},
}
out = self._check(data)
assert out["trl"]["generation_batch_size"] == 8
def test_explicit_generation_batch_size_non_divisible_raises(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"trl": {"num_generations": 4, "generation_batch_size": 6},
}
with pytest.raises(ValueError, match="trl.generation_batch_size"):
self._check(data)
def test_non_grpo_skips_check(self):
# Anything other than rl=grpo should pass through untouched, even
# with non-divisible batch sizes — they're irrelevant to other RL
# methods that don't use group-relative advantages.
data = {
"rl": "dpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {"num_generations": 4},
}
assert self._check(data) is data
def test_no_rl_set_skips_check(self):
data = {
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
}
assert self._check(data) is data
def test_grpo_without_num_generations_skips_check(self):
# If num_generations isn't set, TRL uses its own default — we don't
# have enough info to validate, so the validator must short-circuit
# rather than guess.
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 3,
"trl": {},
}
out = self._check(data)
assert out["rl"] == "grpo"
def test_multi_rank_group_size_check(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 4, # gbs=4
"world_size": 2, # need gbs >= 4*2 = 8
"trl": {"num_generations": 4},
}
with pytest.raises(ValueError, match=r"world_size=2"):
self._check(data)
def test_multi_rank_group_size_satisfied(self):
data = {
"rl": "grpo",
"micro_batch_size": 1,
"gradient_accumulation_steps": 8, # gbs=8 >= 4*2
"world_size": 2,
"trl": {"num_generations": 4},
}
out = self._check(data)
assert out["gradient_accumulation_steps"] == 8
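What these tests enforce can be sketched as a plain function over the config dict. This is a hypothetical stand-in; the real check is `RLValidationMixin.check_grpo_batch_size_divisibility` in axolotl, which handles more edge cases and produces the fix-hint error messages asserted above:

```python
# Hypothetical sketch of the GRPO batch-size checks the tests exercise.
def check_grpo_batch_size(data):
    # Only GRPO uses group-relative advantages; everything else passes through.
    if data.get("rl") != "grpo":
        return data
    trl = data.get("trl") or {}
    num_generations = trl.get("num_generations")
    if num_generations is None:
        # TRL applies its own default; not enough info to validate.
        return data
    if num_generations < 2:
        raise ValueError("GRPO requires num_generations >= 2")
    # An explicit generation_batch_size wins; otherwise derive it from
    # micro_batch_size * gradient_accumulation_steps.
    gbs = trl.get("generation_batch_size")
    if gbs is None:
        gbs = data.get("micro_batch_size", 1) * data.get(
            "gradient_accumulation_steps", 1
        )
    if gbs % num_generations != 0:
        raise ValueError(
            f"generation batch size {gbs} is not divisible by "
            f"num_generations={num_generations}"
        )
    # Multi-rank: the batch must cover at least one full generation group
    # per rank (gbs >= num_generations * world_size, per the tests above).
    world_size = data.get("world_size", 1)
    if gbs < num_generations * world_size:
        raise ValueError(
            f"generation batch size {gbs} < num_generations * world_size "
            f"({num_generations} * {world_size})"
        )
    return data
```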