Compare commits

...

24 Commits

Author SHA1 Message Date
Wing Lian
cec99c4133 fix test dims 2026-04-20 20:45:19 -04:00
Wing Lian
d248242490 support for vllm 0.19.1 2026-04-19 18:09:46 -04:00
Wing Lian
323da791eb bump transformers to 5.5.4 and trl to latest 1.1.0 (#3603)
* bump transformers to 5.5.4 and trl to latest 1.1.0

* more upgrades

* update peft too

* adapt lora_merge to peft 0.19 layer config API

PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.

* allow lora_dropout on mixed attention+MoE configs under peft 0.19

PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).

Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
2026-04-15 09:27:03 -04:00
NanoCode012
6990478163 fix: rename model to adapter_model for fsdp sharded final model (#3585)
* fix: rename model to adapter_model for fsdp sharded final model

* fix: follow upstream transformer shard size

* fix: handle multiple model files

* fix redundant condition, tighten to safetensors, keep shard size small

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:30 -04:00
ゆり
63a58cfec1 feat: support excess_length_strategy for RL trainers (#3578) [skip ci]
* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
  within sequence_len while preserving the full prompt, then decodes
  back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:51:10 -04:00
madScientist10
3985ec2f67 feat: add FineGrainedFP8Config support for model quantization (#3587) [skip ci]
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.

Made-with: Cursor
2026-04-12 20:50:37 -04:00
Joaquin Hui
a44edda6d7 Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 20:50:15 -04:00
Wing Lian
66c3e5a3fd better handling of dora merge on Conv layers in Qwen 3.5 (#3599)
* better handling of dora merge on Conv layers in Qwen 3.5

* address issues from code review

* stricter efficient merges for dora since we now have meta model to reference
2026-04-12 10:57:45 -04:00
Wing Lian
b8358aa5ab [gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598) 2026-04-12 10:29:55 -04:00
Joaquin Hui
e079cf16a2 qwen3_5.jinja: handle list content on system messages (#3595) [skip ci]
* qwen3_5.jinja: handle list content on system messages

The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.

Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.

Fixes #3590

* Address pr for other content types

---------

Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-12 00:58:58 -04:00
Wing Lian
e2f69828d2 [fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci] 2026-04-11 20:22:35 -04:00
Wing Lian
122b50bad6 pre-cache the eot token ids rather than on each iteration (#3594) [skip ci] 2026-04-11 20:05:21 -04:00
Wing Lian
e77a185e86 upgrade transformers to use v5.5.3 (#3593) 2026-04-10 17:08:14 -04:00
Wing Lian
29fa4dedbb Gemma4 fixes and profiler (#3591) 2026-04-10 16:46:17 -04:00
Wing Lian
315cdeede9 handle trainable/masked spans in content and reasoning content (#3592) 2026-04-10 14:11:10 -04:00
NanoCode012
e7a6a5b529 fix: move warning after we've set any overrides (#3589) [skip ci] 2026-04-10 13:00:47 -04:00
NanoCode012
bfb4da1d25 fix: document jinja2 file path support (#3588) [skip ci] 2026-04-10 13:00:26 -04:00
floaty3
4dfa0a59b2 Add uninstall command to cut_cross_entropy import message (#3583) [skip ci] 2026-04-10 13:00:07 -04:00
Wing Lian
4ef608dda3 fix ddp/fsdp w gemma4 (#3584)
* fix ddp/fsdp w gemma4

* address pr comments

* activation offloading fix and update agent docs for gemma4
2026-04-09 20:02:36 -07:00
NanoCode012
7daf7d96f1 fix: regex for unfrozen language tower (#3586) [skip ci]
* fix: regex for unfrozen language tower

* fix: other leftover regex
2026-04-08 08:18:11 -07:00
Wing Lian
7c56809c7f use vllm 0.19.0 for torch 2.10.0 (#3582) 2026-04-07 08:09:49 -07:00
NanoCode012
149178ddb7 chore: cleanup post release v0.16 (#3577)
* fix: remove unneeded debug log

* fix: cleanup

* feat: add dense gemma config and cleanup

* feat: add cce support

* update notes and set torch compile

* fix patch for new number of return vals

* fixes for gemma4

* fix packing bug

* use updated cce for mm

* fix: pass in kv cache func when avail for transformers 5.5

* feat: update examples with flex variant and readme

* gemma4 lora attention kernels

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-06 10:10:52 -07:00
NanoCode012
dc638e723f fix(config): add cce and liger to nemotron-h example (#3573) [skip ci] 2026-04-06 10:10:25 -07:00
Wing Lian
6f15da4cac make it easier for agents to discover docs (#3579) [skip ci]
* make it easier for agents to discover docs

* fixup pr comments
2026-04-06 10:00:55 -07:00
62 changed files with 6136 additions and 189 deletions

View File

@@ -220,6 +220,16 @@ jobs:
run: |
axolotl --help
- name: Verify agent docs are discoverable
run: |
# Agent docs live in docs/agents/ (source of truth) and are resolved
# at runtime from the repo checkout or via `axolotl fetch docs`
axolotl agent-docs --list
axolotl agent-docs | grep -q "Fine-tuning framework"
axolotl agent-docs grpo | grep -q "GRPO"
axolotl agent-docs sft | grep -q "SFT"
python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"
- name: Show HF cache
run: hf cache ls

View File

@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
axolotl fetch examples # Download example configs
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
axolotl agent-docs grpo # Topic-specific agent reference
axolotl config-schema # Dump config JSON schema
```
## Training Methods
@@ -35,6 +38,8 @@ Agent-specific references:
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
## Config Pattern

View File

@@ -3,4 +3,6 @@ include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
include AGENTS.md
recursive-include docs/agents *.md
recursive-include axolotl *.py

View File

@@ -86,7 +86,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python >=3.11 (3.12 recommended)
- PyTorch ≥2.9.1
### Google Colab
@@ -95,6 +95,34 @@ Features:
### Installation
#### Using uv (recommended)
```bash
# install uv if you don't already have it installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
# CUDA 12.8.1 tends to have better package compatibility
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# recommended - install cut-cross-entropy
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
#### Using pip
```bash
@@ -157,6 +185,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
## AI Agent Support
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
```bash
# Show overview and available training methods
axolotl agent-docs
# Topic-specific references
axolotl agent-docs sft # supervised fine-tuning
axolotl agent-docs grpo # GRPO online RL
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
axolotl agent-docs reward_modelling # outcome and process reward models
axolotl agent-docs pretraining # continual pretraining
axolotl agent-docs --list # list all topics
# Dump config schema for programmatic use
axolotl config-schema
axolotl config-schema --field adapter
```
If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.
## 🤝 Getting Help
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support

View File

@@ -0,0 +1,198 @@
# Model Architectures — Agent Reference
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
## VLM (Vision Language Model) Quick Start
All VLM configs require these four lines:
```yaml
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
Decision tree for VLM config:
```text
Is the model multimodal (has vision/audio encoder)?
├─ YES: Add `freeze_mm_modules: true` if training text only
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
│ LoRA: use regex `lora_target_modules` to restrict to language model
└─ NO: Train as a regular text model
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
├─ YES: Add `lora_target_parameters` for expert LoRA
│ Consider ScatterMoE kernels (see Plugins section)
└─ NO: Standard LoRA config
```
## Plugins & Optimizations
### Cut Cross Entropy (CCE)
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
```bash
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```
### ScatterMoE Kernels
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
## Gemma 4
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
### Required settings
```yaml
# Always needed for Gemma4:
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
gradient_checkpointing_kwargs:
use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
```
### Auto-detection
Axolotl auto-detects Gemma4 and applies:
- `use_reentrant: false` for gradient checkpointing
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
### Multi-GPU
| Strategy | Works? | Notes |
|----------|--------|-------|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
FSDP2 config:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
```
### MoE (26B-A4B)
- `enable_moe_block: true`, 256 experts, top-k routing
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
- Expert LoRA targets 3D parameter tensors:
```yaml
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
- ScatterMoE kernel acceleration:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
### VLM (Vision) Training
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
processor_type: AutoProcessor
freeze_mm_modules: true
chat_template: gemma4
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
```
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
### Common issues
| Symptom | Cause | Fix |
|---------|-------|-----|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
### E2B/E4B dense models
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
## Gemma 3
**Models**: `google/gemma-3-*`
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
- `use_reentrant: false` recommended
- Attention mask must be dropped for sample packing (handled automatically)
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
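A minimal config sketch applying these settings (the `base_model` value is illustrative, not a validated recipe):
```yaml
base_model: google/gemma-3-4b-it   # illustrative; any Gemma 3 checkpoint
chat_template: gemma3
ddp_find_unused_parameters: true   # multimodal wrapper leaves unused params under DDP
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false             # recommended for Gemma 3
sample_packing: true               # attention mask handling is automatic
```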
## Qwen 3.5 MoE
**Models**: `Qwen/Qwen3.5-35B-A3B`
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
- 256 experts, 8 active per token
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
```yaml
normalize_weight_scales:
- name_pattern: 'linear_attn\.conv1d\.weight'
threshold: 1.3
```
## General MoE Notes
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
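A detection-only sketch of that check, reusing the Qwen 3.5 pattern from above (placing `dry_run` alongside the pattern entry is an assumption about the option's exact shape):
```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
    dry_run: true   # assumed placement: report outlier scales without rescaling them
```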

View File

@@ -0,0 +1,181 @@
# New Model Support — Agent Reference
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
## Quick Validation Checklist
When testing a new model, run through these checks in order:
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5-2.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
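A rough way to run through these checks from the shell (the exact log strings are assumptions and may differ between versions):
```bash
# 1. config schema / preprocessing errors
axolotl preprocess config.yaml 2>&1 | tee preprocess.log

# 2-5. run a few training steps and inspect the log
axolotl train config.yaml 2>&1 | tee train.log
grep -i "unsupported layer type" train.log        # PEFT warnings (check 2)
grep -i "applying cut cross entropy" train.log    # CCE patch applied? (check 5)
```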
## Loss Debugging
### Expected initial loss
A pretrained model doing SFT should start with loss roughly in the 0.5-2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
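As a quick sanity check, the random-prediction baseline is just the natural log of the vocabulary size:
```python
import math

vocab_size = 262_144          # e.g. a ~262K-token vocabulary
print(math.log(vocab_size))   # ~12.5 -- a first-step loss near this means random predictions
```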
### Direct comparison technique
The fastest way to isolate a loss issue — bypass the trainer entirely:
```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader
cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()
# Forward pass on preprocessed data
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
```
If direct loss is correct (~1.0) but the trainer reports a loss 3-4x higher, check `model_accepts_loss_kwargs` (see below).
### `model_accepts_loss_kwargs` inflation
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
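A hedged sketch of that override (the helper name and the `accelerator.unwrap_model` call are assumptions; see the actual trainer `__init__` for the real logic):
```python
import inspect

def _maybe_disable_loss_kwargs(trainer):
    """Hypothetical helper mirroring the fix described above."""
    model = trainer.accelerator.unwrap_model(trainer.model)
    forward_params = inspect.signature(model.forward).parameters
    if "num_items_in_batch" not in forward_params:
        # forward() only exposes **kwargs and never consumes num_items_in_batch,
        # so keep the default gradient-accumulation loss normalization.
        trainer.model_accepts_loss_kwargs = False
```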
## Multimodal Models (ForConditionalGeneration)
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`
### Why this matters
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|-----------|----------------------|--------------------------------|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |
### Required extra inputs
Multimodal models require special inputs during training even for text-only data:
| Model | Required Input | Value for Text-Only |
|-------|---------------|-------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
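A hedged sketch of that injection (the function name is hypothetical; the key names come from the table above):
```python
import torch

def inject_text_only_mm_inputs(model_type: str, inputs: dict) -> dict:
    """Add the multimodal placeholder inputs for text-only batches."""
    if model_type == "gemma4" and "mm_token_type_ids" not in inputs:
        inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
    elif model_type == "gemma3" and "token_type_ids" not in inputs:
        inputs["token_type_ids"] = torch.zeros_like(inputs["input_ids"])
    return inputs
```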
### Custom layer types and PEFT
Vision towers often use custom module wrappers that PEFT doesn't support:
| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
## Sample Packing
### How packed sequence detection works (transformers ≥ 5.x)
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
packed_sequence_mask = find_packed_sequence_indices(position_ids)
```
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
### Fix for models using `create_causal_mask_mapping`
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
```python
# In compute_loss():
if (
self.args.sample_packing
and model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
```
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
### Models that DON'T need this fix
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
## Attention Backend Selection
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
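For a model whose head_dim exceeds the flash-attn limit (e.g. Gemma4's 512), the config switch looks like:
```yaml
flash_attention: false   # head_dim > 256 is rejected by the flash-attn CUDA kernels
sdp_attention: true      # universal fallback; or flex_attention: true (mind torch_compile Triton OOM)
```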
## Cut Cross Entropy (CCE)
### How CCE patches work
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
### Adding CCE for a new model
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
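A quick way to check step 1 from a Python shell (assumes `PATCH_FNS` is importable from that module path in your installed version):
```python
from cut_cross_entropy.transformers.patch import PATCH_FNS

print("gemma4" in PATCH_FNS)   # is there a model-specific cce_forward patch?
print(sorted(PATCH_FNS))       # all model types with dedicated patches
```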
### Common CCE pitfall
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
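One hedged way to see which class actually got patched is to inspect the loaded model's `forward` (using the `model` from the direct-comparison snippet above); the exact module path of the patched forward is an assumption:
```python
# If CCE is active for this class, forward should resolve to a cce_forward wrapper
# rather than the original transformers modeling module.
print(type(model).__name__)             # e.g. Gemma4ForConditionalGeneration
print(type(model).forward.__module__)   # cut_cross_entropy... if patched, transformers... if not
print(type(model).forward.__qualname__)
```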
## MoE Models
### Dense MLP vs MoE experts
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
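A combined targeting sketch for such a model (module names follow the Gemma4 examples above; other architectures use different names):
```yaml
# Regular nn.Linear LoRA on the dense MLP projections:
lora_target_modules:
  - gate_proj
  - up_proj
  - down_proj
# Expert LoRA on the fused 3D Parameter tensors:
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```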
### ScatterMoE kernels
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```
## Where to Add Model-Specific Fixes
| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |

View File

@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
## Profiling
To profile training and identify optimization opportunities:
```yaml
# Profile steps 3-7 (after warmup/autotuning settles)
profiler_steps_start: 3
profiler_steps: 5
```
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
View the Chrome trace at `chrome://tracing`.
To programmatically inspect the trace:
```bash
python scripts/analyze_profile.py output_dir/
```
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
- **Large matmul kernels**: candidates for fusion or quantization
- **Memory copies (H2D/D2H)**: unnecessary data movement
- **Small frequent kernels**: candidates for kernel fusion
- **Gaps between kernels**: pipeline bubbles from CPU overhead
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
## File Map

View File

@@ -108,6 +108,14 @@ datasets:
type: chat_template
```
::: {.callout-tip}
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
```yaml
chat_template_jinja: ./path/to/my_template.jinja2
```
:::
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens:`.
:::
@@ -294,6 +302,113 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Content parts with per-part training control
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think step by step...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
}
]
}
```
The configuration is the same as standard `chat_template` — no extra fields needed:
```yaml
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
Each content part supports:
- `type`: `"text"` (required)
- `text`: the text value (also accepts `content` or `value` as the key)
- `train`: `true`/`false` (optional) — whether to train on this part
- `weight`: `0`/`1` (optional) — alternative to `train`
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
::: {.callout-warning title="Whitespace at part boundaries"}
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) prepend spaces to word tokens. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
**Split BEFORE spaces** (space goes with the next part):
```json
[
{"type": "text", "text": "Let me think...", "train": false},
{"type": "text", "text": " The answer is 4.", "train": true}
]
```
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
```json
[
{"type": "text", "text": "Let me think... ", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
```json
[
{"type": "text", "text": "Thinking:\n", "train": false},
{"type": "text", "text": "The answer is 4.", "train": true}
]
```
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
:::
::: {.callout-note}
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
:::
##### Per-part training on reasoning_content
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
```{.json filename="data.jsonl"}
{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": true}
]
}
]
}
```
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
::: {.callout-tip}
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
:::
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.

View File

@@ -8,6 +8,7 @@ format:
## Supported Models
- [Gemma-4](#sec-gemma-4) *(NEW)*
- [Mllama](#sec-mllama)
- [Llama4](#sec-llama4)
- [Pixtral](#sec-pixtral)
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-4 {#sec-gemma-4}
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
```yaml
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
chat_template: gemma4
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
lora_target_parameters:
- experts.gate_up_proj
- experts.down_proj
```
::: {.callout-warning}
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
:::
::: {.callout-tip}
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
:::
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
},
{

View File

@@ -26,8 +26,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_r: 32

View File

@@ -26,8 +26,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_r: 32

View File

@@ -22,8 +22,8 @@ output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
- ^model.language_model.*
- ^lm_head.*
adapter: qlora
lora_model_dir:

View File

@@ -1,19 +1,12 @@
# Gemma 4 26B-A4B MoE QLoRA with ScatterMoE kernels
#
# Validated: 50 steps on FineTome-100k, loss 7.4 -> 2.4, single RTX 5090 (32GB)
# Validated: 50 steps on FineTome-100k, loss 8.8 -> 1.8, single RTX 5090 (32GB)
# torch_compile=true: 21 GiB peak VRAM, ~230 tok/s, 336s total
#
# Key notes:
# - Flash Attention 2 is NOT supported (global_head_dim=512 > FA2 max of 256).
# Use sdp_attention instead.
# - Gemma 4 is multimodal (text+vision+audio). For text-only SFT, restrict
# LoRA to the text backbone via lora_target_linear_modules regex.
# - MoE experts use `experts_implementation: scattermoe` — Gemma 4 embeds MoE
# directly in the decoder layer (no SparseMoeBlock), so we register ScatterMoE
# via the transformers ExpertsInterface.
# - Expert LoRA targets are `experts.gate_up_proj` / `experts.down_proj`
# (no `mlp.` prefix, unlike Qwen/Mixtral).
# - micro_batch_size: 1 fits 2048 seq_len on 32GB GPU with SDP attention.
# Use micro_batch_size: 4 with 1024 seq_len, or on 48GB+ GPUs.
# - Max sequence length on 32GB GPU: 2048 (micro_batch_size=1, SDP attention).
# 4096 seq_len OOMs because head_dim=512 forces the math SDP backend to materialize the full score matrix.
# Use 48GB+ GPUs for longer sequences or multi-GPU with FSDP.
base_model: google/gemma-4-26B-A4B
@@ -24,7 +17,7 @@ plugins:
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
torch_compile: false
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
@@ -54,12 +47,9 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders).
# lora_target_modules is intentionally empty — all module targeting is done
# via regex in lora_target_linear_modules below.
lora_target_modules: []
lora_target_linear_modules:
- language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using regex to match only the text decoder attention projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
# MoE expert LoRA (3D Parameter tensors, not nn.Linear)
lora_target_parameters:
@@ -73,7 +63,7 @@ lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project: gemma4-qlora
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
@@ -93,8 +83,7 @@ gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA2 not supported — Gemma4 global_head_dim=512 exceeds FA2 max of 256
flash_attention: false
# FA2 not supported
sdp_attention: true
warmup_ratio: 0.1

View File

@@ -0,0 +1,71 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora-flex
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
flex_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,69 @@
base_model: google/gemma-4-31B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
torch_compile: false
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
strict: false
chat_template: gemma4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:10%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma4-31b-qlora
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Restrict LoRA to text backbone only (skip vision/audio encoders)
# using regex to match only the text decoder attention projections.
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
bnb_config_kwargs:
bnb_4bit_use_double_quant: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
# FA not supported
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

60
examples/gemma4/README.md Normal file
View File

@@ -0,0 +1,60 @@
# Finetune Google's Gemma 4 with Axolotl
[Gemma 4](https://huggingface.co/collections/google/gemma-4) is a family of multimodal models from Google. This guide covers how to train them with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# 26B MoE QLoRA (1x80GB @ ~50 GiB)
axolotl train examples/gemma4/26b-a4b-moe-qlora.yaml
# 31B Dense QLoRA (1x80GB @ ~44 GiB)
axolotl train examples/gemma4/31b-qlora.yaml
# 31B Dense QLoRA Flex Attn (1x80GB @ ~26 GiB)
axolotl train examples/gemma4/31b-qlora-flex.yaml
```
### MoE Expert Quantization & Expert LoRA (26B-A4B only)
The 26B-A4B config uses ScatterMoE kernels via the transformers `ExpertsInterface` and quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
## Flex Attention
Reduce VRAM by ~40% (at the cost of up to half the throughput) by setting the options below (as shown in `examples/gemma4/31b-qlora-flex.yaml`):
```yaml
torch_compile: true
flex_attention: true
```
This works for both the MoE and Dense model.
## Limitations
- **Flash Attention**: FA2 (max head_dim=256) and FA4 (max head_dim=128) cannot support Gemma 4's `global_head_dim=512`. Use SDP or flex attention instead.
- **LoRA kernels**: Not supported due to KV-sharing layers.
- **lora_target_linear**: Incompatible for multimodal models — use `lora_target_modules` with a regex to restrict LoRA to the text backbone.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- You can run full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is resource-intensive and has not been tested.
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Gemma 4 Blog](https://huggingface.co/blog/gemma4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,62 @@
# Gemma 4 E2B Vision LoRA
#
# Fine-tuning LM LoRA adapters on multimodal Gemma4 with vision/multimodal modules frozen.
# Uses the base ProcessingStrategy (auto-detects image_token from processor).
base_model: google/gemma-4-E2B-it
processor_type: AutoProcessor
freeze_mm_modules: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: gemma4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/gemma4-e2b-vision-lora
adapter: lora
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
# Target language model only — vision encoder is frozen via freeze_mm_modules
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

View File

@@ -1,5 +1,15 @@
base_model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
@@ -22,8 +32,6 @@ dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
@@ -31,16 +39,16 @@ lora_r: 16
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
# Attention projection layers (present in ~12 attention layers out of 88)
- q_proj
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:

View File

@@ -1,6 +1,16 @@
# See examples/nemotron-h/README.md for architecture notes and requirements.
base_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.liger.LigerPlugin
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
# LoRA kernel patches are incompatible with this architecture — see README.
lora_mlp_kernel: false
lora_qkv_kernel: false
@@ -23,8 +33,6 @@ dataset_prepared_path: last_run_prepared
sequence_len: 4096
sample_packing: true
use_cut_cross_entropy: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
@@ -36,11 +44,12 @@ lora_target_modules:
- k_proj
- v_proj
- o_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
# To also train MoE expert weights, add them via lora_target_parameters
# (they are 3D nn.Parameter tensors, not nn.Linear — no gate_proj):
# lora_target_parameters:
# - up_proj
# - down_proj
wandb_project:
wandb_entity:

View File

@@ -26,8 +26,8 @@ sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
- model.language_model.*
- lm_head.*
wandb_project:
wandb_entity:

View File

@@ -0,0 +1,62 @@
# Qwen 3.5 35B-A3B MoE Vision LoRA
#
# Vision fine-tuning of the hybrid DeltaNet + Attention MoE model.
# 256 experts, 8 active per token, with early-fusion vision support.
base_model: Qwen/Qwen3.5-35B-A3B
processor_type: AutoProcessor
# Required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:100]
val_set_size: 0
output_dir: ./outputs/qwen35-35b-a3b-vision-lora
adapter: lora
sequence_len: 4096
pad_to_sequence_len: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
max_steps: 10
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

View File

@@ -10,15 +10,15 @@ liger-kernel==0.7.0
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
peft>=0.19.0,<0.20.0
tokenizers>=0.22.1
transformers==5.5.0
transformers==5.5.4
accelerate==1.13.0
datasets==4.5.0
datasets>=4.8.4,<4.9.0
deepspeed>=0.18.6,<0.19.0
trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
trl==1.1.0
hf_xet==1.4.3
kernels==0.13.0
fla-core==0.4.1
flash-linear-attention==0.4.1

1518
scripts/analyze_profile.py Normal file

File diff suppressed because it is too large

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"'
)

View File

@@ -89,7 +89,7 @@ def parse_requirements(extras_require_map):
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm>=0.17.1"]
extras_require_map["vllm"] = ["vllm>=0.19.1"]
elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [

View File

@@ -0,0 +1,108 @@
"""Bundled agent documentation for axolotl.
These docs are optimized for consumption by AI coding agents.
The source of truth is docs/agents/*.md and AGENTS.md in the repo root.
This module resolves those paths at runtime — no files are duplicated
into the package.
For pip-only installs (no repo checkout), run `axolotl fetch docs` first
to download the docs locally.
"""
from pathlib import Path
# Topic name -> path relative to the repo root (AGENTS.md or docs/agents/<topic>.md)
TOPICS = {
"overview": "AGENTS.md",
"sft": "docs/agents/sft.md",
"grpo": "docs/agents/grpo.md",
"preference_tuning": "docs/agents/preference_tuning.md",
"reward_modelling": "docs/agents/reward_modelling.md",
"pretraining": "docs/agents/pretraining.md",
"model_architectures": "docs/agents/model_architectures.md",
"new_model_support": "docs/agents/new_model_support.md",
}
def _find_repo_root() -> Path | None:
"""Walk up from this file to find the repo root (contains AGENTS.md)."""
# In an editable install or repo checkout, walk up from
# src/axolotl/cli/agent_docs/ to find the repo root
current = Path(__file__).resolve().parent
while current != current.parent:
if (current / "AGENTS.md").exists() and (current / "docs" / "agents").is_dir():
return current
current = current.parent
return None
def _find_docs_dir() -> Path | None:
"""Find a fetched docs directory (from `axolotl fetch docs`)."""
# axolotl fetch docs --dest defaults to ./docs/ in cwd
cwd_docs = Path.cwd() / "docs" / "agents"
if cwd_docs.is_dir():
return Path.cwd()
return None
def _resolve_path(topic: str) -> Path:
"""Resolve a topic name to the actual file path."""
if topic not in TOPICS:
available = ", ".join(sorted(TOPICS.keys()))
raise FileNotFoundError(f"Unknown topic: {topic!r}. Available: {available}")
relative_path = TOPICS[topic]
# Try repo root first (editable install / repo checkout)
repo_root = _find_repo_root()
if repo_root:
candidate = repo_root / relative_path
if candidate.exists():
return candidate
# Try cwd (fetched docs via `axolotl fetch docs`)
docs_root = _find_docs_dir()
if docs_root:
candidate = docs_root / relative_path
if candidate.exists():
return candidate
# Also check cwd directly for AGENTS.md
if topic == "overview":
cwd_agents = Path.cwd() / "AGENTS.md"
if cwd_agents.exists():
return cwd_agents
raise FileNotFoundError(
f"Could not find {relative_path!r}. "
f"If you installed axolotl via pip, run `axolotl fetch docs` first "
f"to download the documentation."
)
def get_doc(topic: str = "overview") -> str:
"""Return the content of an agent doc by topic name.
Args:
topic: One of the keys in TOPICS, or "overview" (default).
Returns:
The markdown content of the doc.
Raises:
FileNotFoundError: If the topic can't be found.
"""
return _resolve_path(topic).read_text()
def list_topics() -> dict[str, str]:
"""Return a dict of topic name -> first line (title) of each doc."""
result = {}
for topic in sorted(TOPICS.keys()):
try:
path = _resolve_path(topic)
first_line = path.read_text().split("\n", 1)[0].lstrip("# ").strip()
result[topic] = first_line
except FileNotFoundError:
result[topic] = "(not found — run `axolotl fetch docs`)"
return result

View File

@@ -294,7 +294,9 @@ def merge_lora(config: str, **kwargs):
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.argument(
"directory", type=click.Choice(["examples", "deepspeed_configs", "docs"])
)
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
@@ -303,9 +305,10 @@ def fetch(directory: str, dest: Optional[str]):
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
- docs: Full documentation (Quarto markdown files)
Args:
directory: One of `examples`, `deepspeed_configs`.
directory: One of `examples`, `deepspeed_configs`, `docs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)
@@ -340,6 +343,112 @@ def delinearize_llama4(model: str, output: str):
do_delinearize_llama4(model, output)
@cli.command("agent-docs")
@click.argument("topic", required=False, default=None)
@click.option("--list", "list_topics", is_flag=True, help="List available topics")
def agent_docs(topic: Optional[str], list_topics: bool):
"""Show agent-optimized documentation.
Prints reference docs designed for AI coding agents.
These docs are bundled with the package — no network access needed.
\b
Examples:
axolotl agent-docs # overview (start here)
axolotl agent-docs grpo # GRPO reference
axolotl agent-docs sft # SFT reference
axolotl agent-docs --list # list all topics
"""
from axolotl.cli.agent_docs import get_doc, list_topics as _list_topics
if list_topics:
for name, title in _list_topics().items():
click.echo(f" {name:25s} {title}")
return
if topic is None:
topic = "overview"
try:
click.echo(get_doc(topic))
except FileNotFoundError as exc:
raise click.BadParameter(str(exc)) from exc
@cli.command("config-schema")
@click.option(
"--format",
"output_format",
type=click.Choice(["json", "yaml"]),
default="json",
help="Output format (default: json)",
)
@click.option("--field", help="Show schema for a specific field only")
def config_schema(output_format: str, field: Optional[str]):
"""Dump the full config JSON schema.
Useful for AI agents and tooling to discover all available config options,
their types, defaults, and descriptions.
\b
Examples:
axolotl config-schema # full JSON schema
axolotl config-schema --format yaml # YAML format
axolotl config-schema --field adapter # single field
"""
import json
try:
schema = AxolotlInputConfig.model_json_schema()
except (TypeError, ValueError, AttributeError) as exc:
# Fallback: dump field names, types, and defaults when full schema
# generation fails (e.g. torch.dtype not JSON-serializable)
LOG.warning(
"Full JSON schema generation failed, using simplified fallback: %s", exc
)
fields = {}
for name, field_info in AxolotlInputConfig.model_fields.items():
entry = {}
if field_info.description:
entry["description"] = field_info.description
if field_info.default is not None:
try:
json.dumps(field_info.default)
entry["default"] = field_info.default
except (TypeError, ValueError):
entry["default"] = str(field_info.default)
annotation = field_info.annotation
if annotation is not None:
entry["type"] = str(annotation)
fields[name] = entry
schema = {
"properties": fields,
"_note": "simplified schema (full generation failed)",
}
if field:
props = schema.get("properties", {})
if field not in props:
# Try case-insensitive match
matches = [k for k in props if k.lower() == field.lower()]
if matches:
field = matches[0]
else:
raise click.BadParameter(
f"Unknown field: {field!r}. "
f"Omit --field to dump the full schema, "
f"or pipe to jq: axolotl config-schema | jq '.properties | keys'"
)
schema = {field: props[field]}
if output_format == "yaml":
import yaml # pylint: disable=import-outside-toplevel
click.echo(yaml.dump(schema, default_flow_style=False, sort_keys=False))
else:
click.echo(json.dumps(schema, indent=2))
cli.add_command(lm_eval)

View File

@@ -115,6 +115,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
simulate_nf4_experts=simulate_nf4_experts,
nf4_blocksize=nf4_blocksize,
nf4_double_quant=nf4_double_quant,
trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
)
LOG.debug("Memory-efficient LoRA merge completed successfully!")

View File

@@ -17,6 +17,93 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _build_layer_type_map(
base_model_path: Path, trust_remote_code: bool = False
) -> dict[str, str]:
"""Build a map of module_name -> layer_type using a meta-device model.
Instantiates the model architecture on the meta device (zero memory)
to inspect which modules are Linear vs Conv1d/Conv2d/Conv3d.
This avoids relying on weight tensor ndim heuristics.
"""
import json as _json
import torch.nn as nn
from transformers import AutoConfig
config_path = base_model_path / "config.json"
if not config_path.exists():
return {}
try:
with open(config_path) as f:
model_config = _json.load(f)
except (OSError, _json.JSONDecodeError):
return {}
architectures = model_config.get("architectures", [])
if not architectures:
return {}
try:
config = AutoConfig.from_pretrained(
str(base_model_path), trust_remote_code=trust_remote_code
)
except Exception:
LOG.debug("Could not load config for layer type introspection")
return {}
# Determine the right Auto class from architectures
from transformers import (
AutoModel,
AutoModelForCausalLM,
)
auto_classes = [AutoModelForCausalLM, AutoModel]
try:
from transformers import AutoModelForImageTextToText
auto_classes.insert(0, AutoModelForImageTextToText)
except ImportError:
pass
model = None
for auto_cls in auto_classes:
try:
with torch.device("meta"):
model = auto_cls.from_config(
config, trust_remote_code=trust_remote_code
)
break
except Exception: # noqa: BLE001
LOG.debug(
"Could not instantiate meta model with %s, trying next",
auto_cls.__name__,
)
if model is None:
LOG.debug("Could not instantiate meta model for layer type introspection")
return {}
layer_types = {}
for name, module in model.named_modules():
if isinstance(module, nn.Conv3d):
layer_types[name] = "Conv3d"
elif isinstance(module, nn.Conv2d):
layer_types[name] = "Conv2d"
elif isinstance(module, nn.Conv1d):
layer_types[name] = "Conv1d"
elif isinstance(module, nn.Linear):
layer_types[name] = "Linear"
del model
LOG.debug(
f"Layer type map: {len(layer_types)} modules "
f"({sum(1 for v in layer_types.values() if 'Conv' in v)} conv layers)"
)
return layer_types
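The meta-device instantiation above costs no memory because the parameters are never materialized; a standalone sketch of the same idea, with a small public checkpoint name used purely for illustration:

```python
# Minimal sketch of meta-device layer-type introspection (checkpoint name is illustrative).
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")
with torch.device("meta"):  # weights become meta tensors: shapes only, no storage
    model = AutoModelForCausalLM.from_config(config)

layer_types = {
    name: type(module).__name__
    for name, module in model.named_modules()
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d))
}
print(f"{len(layer_types)} linear/conv modules found")
```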
def _simulate_nf4_roundtrip(
tensor: torch.Tensor,
blocksize: Optional[int] = None,
@@ -191,6 +278,7 @@ def _build_peft_layer_and_get_delta(
adapter_name: str = "default",
is_param_wrapper: bool = False,
magnitude: Optional[torch.Tensor] = None,
layer_type: Optional[str] = None,
) -> torch.Tensor:
"""
Use PEFT's own layer classes to compute the LoRA delta weight.
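For a plain Linear LoRA layer the delta this resolves to is the standard low-rank update; a rough sketch of the math being delegated to PEFT (assuming no DoRA and no fan_in_fan_out):

```python
# Sketch of the delta PEFT computes for a Linear LoRA layer (no DoRA, no fan_in_fan_out).
import torch

def lora_delta(lora_a: torch.Tensor, lora_b: torch.Tensor,
               lora_alpha: float, r: int, use_rslora: bool = False) -> torch.Tensor:
    # Standard LoRA scaling is alpha/r; rsLoRA uses alpha/sqrt(r).
    scaling = lora_alpha / (r ** 0.5 if use_rslora else r)
    return scaling * (lora_b @ lora_a)  # [out, r] @ [r, in] -> [out, in]

# DoRA cannot be expressed as a simple additive delta, which is why the DoRA
# branches below instead call PEFT's merge() and take merged_weight - original_weight.
```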
@@ -211,7 +299,7 @@ def _build_peft_layer_and_get_delta(
out_features = lora_b.shape[0]
lora_alpha = lora_config_dict.get("lora_alpha", lora_config_dict.get("r", 1))
use_rslora = bool(lora_config_dict.get("use_rslora", False))
use_dora = bool(lora_config_dict.get("use_dora", False)) and magnitude is not None
use_dora = bool(lora_config_dict.get("use_dora", False))
if is_param_wrapper:
from peft.tuners.lora.layer import ParamWrapper
@@ -227,18 +315,106 @@ def _build_peft_layer_and_get_delta(
"weight", nn.Parameter(base_tensor.clone(), requires_grad=False)
)
# ParamWrapper rejects dropout/fan_in_fan_out/lora_bias/use_dora, so
# build a minimal config with only the fields it accepts.
pw_config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
lora_dropout=0.0,
fan_in_fan_out=False,
use_rslora=use_rslora,
use_dora=False,
lora_bias=False,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
layer = ParamWrapper(
fake,
adapter_name=adapter_name,
parameter_name="weight",
config=pw_config,
r=r,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
return layer.get_delta_weight(adapter_name)
elif (
(layer_type and "Conv" in layer_type) or (layer_type is None and lora_a.ndim > 2)
):
# Conv layer detected via model introspection (or ndim fallback)
from peft.tuners.lora import layer as peft_lora_layer
# Determine conv type from layer_type map or fall back to ndim
if layer_type and "Conv" in layer_type:
conv_type: str = layer_type
else:
ndim = lora_a.ndim
_conv_map = {3: "Conv1d", 4: "Conv2d", 5: "Conv3d"}
if ndim not in _conv_map:
raise ValueError(
f"Unsupported LoRA weight dimensionality {ndim} for conv layer"
)
conv_type = _conv_map[ndim]
LOG.warning(
f"Using ndim-based fallback for conv detection (ndim={ndim}). "
f"Consider providing layer_type from meta-device introspection."
)
conv_cls_map = {"Conv1d": nn.Conv1d, "Conv2d": nn.Conv2d, "Conv3d": nn.Conv3d}
ConvCls = conv_cls_map[conv_type]
PeftConvCls = getattr(peft_lora_layer, conv_type)
# Reconstruct conv parameters from base tensor and lora_a shapes
# base_tensor: [out_channels, in_channels/groups, *kernel_size]
# lora_a: [r, in_channels/groups, *kernel_size]
# lora_b: [out_channels, r, *ones]
out_channels = base_tensor.shape[0]
in_channels = base_tensor.shape[1]
kernel_size = tuple(base_tensor.shape[2:])
stride = (1,) * (base_tensor.ndim - 2)
padding = (0,) * (base_tensor.ndim - 2)
base_layer = ConvCls(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
bias=False,
)
base_layer.weight.data = base_tensor.clone()
conv_config = LoraConfig(
r=r_total,
lora_alpha=lora_alpha,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer = PeftConvCls(
base_layer,
adapter_name=adapter_name,
config=conv_config,
r=r_total,
lora_alpha=lora_alpha,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
if use_dora:
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for conv layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
return base_layer.weight.data - base_tensor
return layer.get_delta_weight(adapter_name)
else:
from peft.tuners.lora.layer import Linear as LoraLinear
@@ -251,15 +427,20 @@ def _build_peft_layer_and_get_delta(
or lora_config_dict.get("lora_fan_in_fan_out", False)
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
linear_config = LoraConfig(
r=r_total,
lora_alpha=lora_alpha,
fan_in_fan_out=fan_in_fan_out,
use_rslora=use_rslora,
use_dora=use_dora,
)
layer = LoraLinear(
base_layer,
adapter_name=adapter_name,
config=linear_config,
r=r_total,
lora_alpha=lora_alpha,
)
layer.lora_A[adapter_name].weight.data = lora_a
layer.lora_B[adapter_name].weight.data = lora_b
@@ -267,6 +448,12 @@ def _build_peft_layer_and_get_delta(
# DoRA merges magnitude normalization into the weight directly.
# Use PEFT's merge() which handles DoRA internally, then
# compute the delta as merged_weight - original_weight.
if magnitude is None:
raise ValueError(
f"DoRA merge requires a magnitude vector but none was found "
f"for linear layer (adapter={adapter_name}). Check that the "
f"adapter checkpoint contains lora_magnitude_vector weights."
)
mag_layer = layer.lora_magnitude_vector[adapter_name]
mag_layer.weight = nn.Parameter(magnitude)
layer.merge(adapter_names=[adapter_name])
@@ -382,6 +569,7 @@ def _merge_tensor_with_lora(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[torch.Tensor, bool]:
"""
Helper function to merge a single tensor with its corresponding LoRA weights.
@@ -426,12 +614,30 @@ def _merge_tensor_with_lora(
if use_dora
else None
)
# Look up layer type from meta-device model introspection
_layer_type = None
if layer_type_map:
mod_path = key.rsplit(".weight", 1)[0] if key.endswith(".weight") else key
_layer_type = layer_type_map.get(mod_path)
# Try common prefix variations (e.g. with/without "model." prefix)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
merged_tensor = (
(tensor.to(device).to(torch.float32) + delta.to(torch.float32))
@@ -556,6 +762,7 @@ def _fuse_and_unfuse_with_merge(
nf4_double_quant: bool = True,
use_dora: bool = False,
weight_renamings: Optional[Dict[str, str]] = None,
layer_type_map: Optional[Dict[str, str]] = None,
) -> tuple[Dict[str, torch.Tensor], int, set]:
"""
For tensors matching WeightConverter patterns (MoE expert weights):
@@ -696,12 +903,32 @@ def _fuse_and_unfuse_with_merge(
if use_dora
else None
)
# Look up layer type for the fused key
_layer_type = None
if layer_type_map:
mod_path = (
fused_key.rsplit(".weight", 1)[0]
if fused_key.endswith(".weight")
else fused_key
)
_layer_type = layer_type_map.get(mod_path)
if _layer_type is None:
for prefix in [
"model.",
"model.language_model.",
"model.language_model.model.",
]:
_layer_type = layer_type_map.get(prefix + mod_path)
if _layer_type:
break
delta = _build_peft_layer_and_get_delta(
lora_a.to(device),
lora_b.to(device),
lora_config_dict,
fused_tensor.to(device),
magnitude=magnitude.to(device) if magnitude is not None else None,
layer_type=_layer_type,
)
fused_tensor = (
(
@@ -740,6 +967,7 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: bool = False,
nf4_blocksize: Optional[int] = None,
nf4_double_quant: bool = True,
trust_remote_code: bool = False,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
@@ -750,6 +978,8 @@ def merge_lora_sharded_efficient(
simulate_nf4_experts: Apply NF4 roundtrip only to MoE expert tensors
(for quantize_moe_experts). Expert tensors are identified by having
"expert" in the key name and ndim >= 3.
trust_remote_code: Whether to trust remote code when loading model
config for layer-type introspection. Defaults to False for safety.
"""
base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
@@ -780,6 +1010,10 @@ def merge_lora_sharded_efficient(
use_dora = bool(lora_config_dict.get("use_dora", False))
# Build layer type map via meta-device model introspection
layer_type_map = _build_layer_type_map(
base_model_path, trust_remote_code=trust_remote_code
)
unsupported_methods = []
# Check for AdaLoRA (Adaptive LoRA)
@@ -904,6 +1138,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_count += fused_merged
@@ -926,6 +1161,7 @@ def merge_lora_sharded_efficient(
nf4_double_quant=nf4_double_quant,
use_dora=use_dora,
weight_renamings=weight_renamings,
layer_type_map=layer_type_map,
)
merged_tensors[key] = merged_tensor
if was_merged:

View File

@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
SkipEvalOnResumeCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.resume_from_checkpoint:
callbacks.append(SkipEvalOnResumeCallback())
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

View File

@@ -100,6 +100,27 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
# Gemma4 (and similar multimodal models) declare **kwargs in forward() for
# extra inputs like mm_token_type_ids. HF Trainer interprets VAR_KEYWORD as
# "the model handles num_items_in_batch internally" and skips the loss ÷
# gradient_accumulation_steps normalisation, which inflates the *logged* loss
# (the gradient itself is still correct). Override to False when the model
# doesn't actually consume num_items_in_batch.
if self.model_accepts_loss_kwargs:
model_to_check = self.accelerator.unwrap_model(self.model)
if hasattr(model_to_check, "base_model"): # PEFT wrapper
model_to_check = model_to_check.base_model
if hasattr(model_to_check, "model"):
model_to_check = model_to_check.model
fwd = getattr(model_to_check, "forward", None)
if fwd is not None:
import inspect
params = inspect.signature(fwd).parameters
if "num_items_in_batch" not in params:
self.model_accepts_loss_kwargs = False
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
@@ -383,13 +404,29 @@ class AxolotlTrainer(
# Gemma4 requires mm_token_type_ids during training (even for text-only).
# Inject zeros (= text token type) when not provided by the data collator.
# Use unwrap_model to handle DDP/FSDP wrappers that don't proxy .config.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and getattr(getattr(model, "config", None), "model_type", None) == "gemma4"
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
# Gemma4 (and Gemma3): transformers' masking_utils detects packed sequences
# from position_ids, but only when attention_mask is None. When sample
# packing is active the collator provides an all-ones attention_mask that
# prevents this detection — remove it so the model builds the correct
# per-sequence causal masks.
if (
self.args.sample_packing
and _model_type in ("gemma4", "gemma3")
and "attention_mask" in inputs
and "position_ids" in inputs
):
del inputs["attention_mask"]
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -398,6 +435,23 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
# Gemma4ForConditionalGeneration computes loss with a manual
# nn.CrossEntropyLoss() that bypasses proper num_items_in_batch
# normalization and does redundant attention_mask filtering.
# Compute loss externally using the standard loss_function instead.
if _model_type == "gemma4" and "labels" in inputs:
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
unwrapped = self.accelerator.unwrap_model(model)
vocab_size = unwrapped.config.get_text_config().vocab_size
loss = unwrapped.loss_function(
logits, labels, vocab_size, num_items_in_batch=num_items_in_batch
)
if return_outputs:
return loss, outputs
return loss
return super().compute_loss(
model,
inputs,
@@ -410,6 +464,21 @@ class AxolotlTrainer(
LOG.info("Running evaluation step...")
return super().evaluate(*args, **kwargs)
@override
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
# Gemma4 requires mm_token_type_ids even during evaluation.
_unwrapped = self.accelerator.unwrap_model(model)
_model_type = getattr(getattr(_unwrapped, "config", None), "model_type", None)
if (
"mm_token_type_ids" not in inputs
and "input_ids" in inputs
and _model_type == "gemma4"
):
inputs["mm_token_type_ids"] = torch.zeros_like(inputs["input_ids"])
return super().prediction_step(
model, inputs, prediction_loss_only, ignore_keys=ignore_keys
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"
```
## Usage
@@ -44,6 +44,7 @@ plugins:
- gemma3_text
- gemma3n
- gemma3n_text
- gemma4
- glm
- glm4
- glm4_moe

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
'`pip uninstall -y cut-cross-entropy && pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88"`'
)

View File

@@ -146,10 +146,6 @@ Gemma 4 (e.g. `google/gemma-4-26B-A4B`) has a unique hybrid MoE architecture:
Because there is no SparseMoeBlock class to patch, Gemma 4 uses a different integration path: we register `"scattermoe"` as a custom implementation in the transformers `ExpertsInterface`, and set `experts_implementation: scattermoe` in the config. The `@use_experts_implementation` decorator on `Gemma4TextExperts` then dispatches to our ScatterMoE kernel automatically. The router is untouched — it runs as-is.
**Important limitations:**
- **Flash Attention 2 is not supported** — Gemma 4 uses `global_head_dim: 512` for full attention layers, which exceeds FA2's maximum head dimension of 256. Use `sdp_attention: true` instead.
- **Multimodal model**: Gemma 4 includes vision and audio encoders. For text-only SFT, use `lora_target_linear_modules` with a regex to restrict LoRA to the text backbone (e.g. `language_model\.model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj`).
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).

View File

@@ -53,28 +53,6 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):

View File

@@ -222,6 +222,56 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
elif cfg.model_config_type in ("gemma4", "gemma4_text"):
# Gemma4: offset=0 (NOT 1 like Gemma3), in_place=False required for
# gradient checkpointing compatibility, RoPE incompatible (separate q/k).
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from transformers.models.gemma4 import modeling_gemma4
if cfg.liger_rms_norm:
_OrigGemma4RMSNorm = modeling_gemma4.Gemma4RMSNorm
class _LigerGemma4RMSNorm(LigerRMSNorm):
"""LigerRMSNorm for Gemma4 with in_place=False and with_scale support."""
def __new__(cls, dim, eps=1e-6, with_scale=True):
if not with_scale:
return _OrigGemma4RMSNorm(dim, eps, with_scale=False)
return super().__new__(cls)
def __init__(self, dim, eps=1e-6, with_scale=True):
if not with_scale:
return
# offset=0.0 (standard), in_place=False (gradient checkpointing safe)
super().__init__(
dim, eps, offset=0.0, casting_mode="llama", in_place=False
)
modeling_gemma4.Gemma4RMSNorm = _LigerGemma4RMSNorm
if cfg.liger_glu_activation:
class _LigerGemma4MLP(LigerGEGLUMLP):
def __init__(self, config, layer_idx=None):
super().__init__(config)
modeling_gemma4.Gemma4TextMLP = _LigerGemma4MLP
if cfg.liger_rope:
LOG.warning(
"Liger RoPE is not compatible with Gemma4 (separate q/k application). Skipping."
)
if cfg.liger_layer_norm:
modeling_gemma4.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
modeling_gemma4.nn.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
LOG.warning(
"Liger fused linear cross entropy is not compatible with Gemma4. Skipping."
)
LOG.info(
f"Applied Liger kernels for gemma4: "
f"rms_norm={cfg.liger_rms_norm}, glu={cfg.liger_glu_activation}, "
f"rope=False (incompatible), layer_norm={cfg.liger_layer_norm}"
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward

View File

@@ -0,0 +1,529 @@
"""
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
Fuses three operations into one kernel launch:
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
3. (optional) RMSNorm without scale (for v_norm)
This eliminates two intermediate tensor materializations per Q/K path and the
allocation churn from rotate_half / apply_rotary_pos_emb.
Shapes:
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
sin: (rows, head_dim) — same as cos
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_rope_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
n_heads,
eps,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused forward:
x_norm = x / rms(x) [* weight] (RMSNorm)
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
rotate_half swaps first/second halves and negates the first:
rotate_half([a, b]) = [-b, a]
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
"""
row_idx = tl.program_id(0).to(tl.int64)
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
cs_row_idx = row_idx // n_heads
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
# Load input row
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
# RMSNorm: compute 1/rms
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
# Normalize
X_norm = X_fp32 * rstd
# Apply weight if present (with_scale=True)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
X_norm = X_norm * W_row
# RoPE: load cos/sin (broadcast across heads)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
sin_row = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
# for col >= half_dim, take X_norm[col - half_dim]
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_mask = rot_offsets < n_cols
X_rot = tl.load(
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
).to(tl.float32)
# Re-normalize the rotated values
X_rot_norm = X_rot * rstd
if HAS_WEIGHT:
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
tl.float32
)
X_rot_norm = X_rot_norm * W_rot
# Negate the first half (rotate_half negates x2, which becomes the first half)
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
X_rot_norm = X_rot_norm * sign
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
Y_row = X_norm * cos_row + X_rot_norm * sin_row
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_rope_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
n_heads,
rows_per_program,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = RoPE(RMSNorm(X, W))
cos/sin indexed by row_idx // n_heads for per-head broadcast.
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
for row_idx in range(row_start, row_end):
cs_row_idx = row_idx // n_heads
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# dN = dY * cos + rotate_half^T(dY * sin)
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
#
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
# This is equivalent to rotating (dY * sin) because the rotation
# just permutes which elements are multiplied.
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_mask = rot_offsets < n_cols
dY_rot = tl.load(
dY_ptr + row_idx * dY_row_stride + rot_offsets,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
sin_rot = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
# Pre-weight normalized: n = rstd * x
n = X_row * rstd
if HAS_WEIGHT:
dW_acc += dN * n
dm = dN * W_row
else:
dm = dN
# RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
dot_dm_x = tl.sum(dm * X_row, axis=0)
dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
if HAS_WEIGHT:
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
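The `dX_row` line above is the closed-form RMSNorm backward; a quick autograd check of that identity (a sketch, independent of the Triton code):

```python
# Verify dX = rstd * (dm - (1/d) * rstd^2 * dot(dm, x) * x) against autograd.
import torch

d, eps = 64, 1e-6
x = torch.randn(d, dtype=torch.float64, requires_grad=True)
w = torch.randn(d, dtype=torch.float64)
g = torch.randn(d, dtype=torch.float64)  # upstream gradient dN

rstd = torch.rsqrt(x.pow(2).mean() + eps)
(x * rstd * w).backward(g)

m = g * w  # dm in the kernel's notation
manual = rstd * (m - (rstd**2 / d) * (m * x).sum() * x)
print(torch.allclose(x.grad, manual))  # True
```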
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
"""
Args:
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
W: (head_dim,) or None — RMSNorm weight
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
eps: float
n_heads: int — number of attention heads (for cos/sin indexing)
Returns:
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
"""
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
has_weight = W is not None
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_rope_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
W if has_weight else X, # dummy pointer when no weight
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
n_heads,
eps,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y, X, RSTD, BLOCK_SIZE, num_warps
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
n_rows, n_cols = dY.shape
has_weight = W is not None
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
rows_per_program = math.ceil(n_rows / sm_count)
dX = torch.empty_like(X)
if has_weight:
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
else:
_dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
_rms_norm_rope_backward_kernel[(sm_count,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W if has_weight else X, # dummy
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
n_heads,
rows_per_program,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
return dX, dW
class FusedRMSNormRoPEFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, cos, sin, eps, n_heads):
"""
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, head_dim) — broadcast across heads
sin: (B*S, head_dim) — broadcast across heads
n_heads: int
"""
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
X,
W,
cos,
sin,
eps,
n_heads,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_heads = n_heads
ctx.has_weight = W is not None
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, cos, sin, RSTD = ctx.saved_tensors
dX, dW = rms_norm_rope_backward(
dY,
X,
W,
cos,
sin,
RSTD,
ctx.n_heads,
ctx.BLOCK_SIZE,
ctx.num_warps,
)
return dX, dW, None, None, None, None
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
Apply fused RMSNorm + RoPE.
Args:
x: (batch, seq_len, num_heads, head_dim) — after projection + view
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
eps: float — RMSNorm epsilon
Returns:
y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
"""
shape = x.shape # (B, S, H, D)
B, S, H, D = shape
# Flatten to 2D: (B*S*H, D)
x_flat = x.reshape(-1, D).contiguous()
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
# by dividing the row_idx by H to get the cos/sin row
cos_flat = cos.reshape(B * S, D).contiguous()
sin_flat = sin.reshape(B * S, D).contiguous()
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
return y_flat.view(shape)
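A minimal smoke test of the wrapper, assuming a CUDA device with Triton available; the random cos/sin stand in for real rotary embeddings:

```python
# Illustrative smoke test for fused_rms_norm_rope (requires CUDA + Triton).
import torch

B, S, H, D = 2, 16, 8, 256
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
y.float().sum().backward()  # exercises the fused backward kernel too
print(y.shape, x.grad.shape, w.grad.shape)
```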
@triton.jit
def _rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""RMSNorm without scale weight: y = x / rms(x)"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
Y_row = X_fp32 * rstd
tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
@triton.jit
def _rms_norm_noscale_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
RSTD_ptr,
RSTD_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""Backward for y = x * rstd (no weight)."""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
dot_dy_x = tl.sum(dY_row * X_row, axis=0)
dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
)
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
"""RMSNorm without learnable scale — used for Gemma4's v_norm."""
@staticmethod
@ensure_contiguous
def forward(ctx, X, eps):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, RSTD)
ctx.n_cols = n_cols
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, RSTD = ctx.saved_tensors
n_rows = X.shape[0]
dX = torch.empty_like(X)
_rms_norm_noscale_backward_kernel[(n_rows,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
RSTD,
RSTD.stride(0),
ctx.n_cols,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
return dX, None
def fused_rms_norm_noscale(x, eps=1e-6):
"""
RMSNorm without scale for v_norm.
Args:
x: (batch, seq_len, num_heads, head_dim)
Returns:
y: same shape, normalized
"""
shape = x.shape
x_flat = x.reshape(-1, shape[-1])
y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
return y_flat.view(shape)

View File

@@ -1297,6 +1297,339 @@ def apply_lora_qkv(
return Q, K, V
class LoRA_QK(torch.autograd.Function):
"""Optimized LoRA QK implementation for models where v_proj is None.
Used by models like Gemma4 with attention_k_eq_v=True, where key states are
reused as value states. Only Q and K projections are fused; the caller
returns K a second time as V so that autograd accumulates key+value gradients
into a single dK.
Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
X_drop: torch.Tensor | None,
# Q params
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
q_lora_bias: torch.Tensor | None,
q_magnitude: torch.Tensor | None,
# K params
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
k_lora_bias: torch.Tensor | None,
k_magnitude: torch.Tensor | None,
# Flags
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
has_dropout = X_drop is not None
has_dora = q_magnitude is not None
if has_dora:
dtype = X.dtype
X_lora = X_drop if has_dropout else X
# Compute Q with DoRA
Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None)
Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype)
q_mag_scale = _compute_dora_scale(
q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype
)
Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora)
if q_bias is not None:
Q = Q + q_bias
# Compute K with DoRA
K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None)
K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype)
k_mag_scale = _compute_dora_scale(
k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype
)
K = k_mag_scale.unsqueeze(0) * (K_base + K_lora)
if k_bias is not None:
K = K + k_bias
Q_combined = Q_base + Q_lora
K_combined = K_base + K_lora
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
)
else:
# Standard LoRA (with optional dropout and bias)
Q = matmul_lora(
X,
q_weight,
q_bias,
q_quant,
q_A,
q_B,
q_scale,
X_drop=X_drop,
lora_bias=q_lora_bias,
)
K = matmul_lora(
X,
k_weight,
k_bias,
k_quant,
k_A,
k_B,
k_scale,
X_drop=X_drop,
lora_bias=k_lora_bias,
)
dtype = X.dtype
ctx.save_for_backward(
X,
X_drop if has_dropout else X,
q_A.to(dtype) if q_A is not None else q_A,
q_B.to(dtype) if q_B is not None else q_B,
k_A.to(dtype) if k_A is not None else k_A,
k_B.to(dtype) if k_B is not None else k_B,
q_lora_bias,
k_lora_bias,
)
ctx.scales = (q_scale, k_scale)
ctx.quants = (q_quant, k_quant)
ctx.weights = (q_weight, k_weight)
ctx.inplace = inplace
ctx.has_dropout = has_dropout
ctx.has_dora = has_dora
return Q, K
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
):
q_weight, k_weight = ctx.weights
q_quant, k_quant = ctx.quants
q_scale, k_scale = ctx.scales
has_dropout = ctx.has_dropout
has_dora = ctx.has_dora
if has_dora:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_magnitude,
k_magnitude,
q_mag_scale,
k_mag_scale,
Q_combined,
K_combined,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
else:
(
X,
X_lora,
A_q,
B_q,
A_k,
B_k,
q_lora_bias,
k_lora_bias,
) = ctx.saved_tensors
q_magnitude = k_magnitude = None
q_mag_scale = k_mag_scale = None
Q_combined = K_combined = None
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
X = X.view(-1, X.shape[-1])
X_lora = X_lora.view(-1, X_lora.shape[-1])
d_q_mag = d_k_mag = None
d_q_lora_bias = d_k_lora_bias = None
if has_dora:
Q_combined = Q_combined.view(-1, Q_combined.shape[-1])
K_combined = K_combined.view(-1, K_combined.shape[-1])
d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude
d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude
q_grad = q_grad * q_mag_scale.unsqueeze(0)
k_grad = k_grad * k_mag_scale.unsqueeze(0)
# LoRA bias gradients
if q_lora_bias is not None:
d_q_lora_bias = q_scale * q_grad.sum(dim=0)
if k_lora_bias is not None:
d_k_lora_bias = k_scale * k_grad.sum(dim=0)
X_lora_t = X_lora.t()
d_A_q = d_B_q = d_A_k = d_B_k = None
grad_B_q = grad_B_k = None
if A_q is not None and B_q is not None:
grad_B_q = q_grad @ B_q
d_A_q = torch.empty_like(A_q.t())
d_B_q = torch.empty_like(B_q.t())
d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0)
d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0)
if A_k is not None and B_k is not None:
grad_B_k = k_grad @ B_k
d_A_k = torch.empty_like(A_k.t())
d_B_k = torch.empty_like(B_k.t())
d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0)
d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0)
# Base path input gradient
out_buffer = X if ctx.inplace else None
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight_t
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight_t
# LoRA path input gradient
if has_dropout:
grad_X_drop = torch.zeros_like(X_lora)
if grad_B_q is not None:
grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale)
else:
grad_X_drop = None
if grad_B_q is not None:
grad_X.addmm_(grad_B_q, A_q, alpha=q_scale)
if grad_B_k is not None:
grad_X.addmm_(grad_B_k, A_k, alpha=k_scale)
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
grad_X = grad_X.view(batch, seq_len, -1)
if grad_X_drop is not None:
grad_X_drop = grad_X_drop.view(batch, seq_len, -1)
# Return gradients for all forward inputs:
# X, X_drop,
# q: weight, bias, quant, A, B, scale, lora_bias, magnitude
# k: weight, bias, quant, A, B, scale, lora_bias, magnitude
# inplace
return (
grad_X,
grad_X_drop,
# Q
None,
None,
None,
d_A_q,
d_B_q,
None,
d_q_lora_bias,
d_q_mag,
# K
None,
None,
None,
d_A_k,
d_B_k,
None,
d_k_lora_bias,
d_k_mag,
# inplace
None,
)
def apply_lora_qk(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query and Key projections for models where v_proj is None.
When v_proj is None (e.g. Gemma4 attention_k_eq_v), key states are reused as
value states. Returns (Q, K, K) — the caller's patched forward will use K as V.
Because K is returned twice, autograd accumulates gradients from both the key and
value paths into dK before calling LoRA_QK.backward.
Supports bias, dropout, and DoRA.
"""
QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj)
# Apply dropout outside autograd.Function (shared mask for Q, K)
X_drop = _apply_dropout(Qdrop, X, self.training)
Q, K = LoRA_QK.apply(
X,
X_drop,
# Q
QW,
Qb,
QW_quant,
QA,
QB,
QS,
Qlb,
Qmag,
# K
KW,
Kb,
KW_quant,
KA,
KB,
KS,
Klb,
Kmag,
# Flags
inplace,
)
return Q, K, K
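Returning K twice works because autograd sums the gradients from every use of a tensor; a tiny standalone illustration, unrelated to the classes above:

```python
# Why (Q, K, K) accumulates key + value gradients into a single dK.
import torch

k = torch.randn(4, 8, requires_grad=True)
key_path = (k * 2).sum()    # stand-in for the "key" use of K
value_path = (k * 3).sum()  # stand-in for the "value" use of the same K
(key_path + value_path).backward()
print(torch.allclose(k.grad, torch.full_like(k, 5.0)))  # True: 2 + 3
```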
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection.

View File

@@ -67,12 +67,165 @@ def find_all_linear_names(model):
return list(lora_module_names)
def _patch_peft_clippable_linear():
"""Patch PEFT to handle Gemma4ClippableLinear which wraps nn.Linear.
Gemma4's vision tower uses ClippableLinear (a thin wrapper around nn.Linear
that clips activations). PEFT doesn't recognise it as a supported layer type,
so we redirect LoRA injection to the inner ``.linear`` child instead.
"""
try:
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4ClippableLinear as _cls,
)
except ImportError:
return
from peft.tuners.lora.model import LoraModel
if getattr(LoraModel, "_axolotl_clippable_patched", False):
return
_orig = LoraModel._create_and_replace
def _patched(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=None,
**kw,
):
if isinstance(target, _cls):
# Redirect to the inner nn.Linear so PEFT can wrap it normally.
return _orig(
self,
peft_config,
adapter_name,
target.linear,
"linear",
target,
current_key=current_key,
**kw,
)
return _orig(
self,
peft_config,
adapter_name,
target,
target_name,
parent,
current_key=current_key,
**kw,
)
LoraModel._create_and_replace = _patched
LoraModel._axolotl_clippable_patched = True
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
"""Check whether PEFT will auto-populate target_parameters for this model.
PEFT 0.19's ``convert_peft_config_for_transformers`` rewrites old MoE
``target_modules`` (e.g. ``w1``/``w2``/``w3`` on Mixtral) into
``target_parameters`` (``gate_up_proj``/``down_proj``) because
transformers v5 fused those expert linears into 3D ``nn.Parameter``
tensors. PEFT wraps the resulting 3D params with ``ParamWrapper``,
which rejects ``lora_dropout != 0``. This probe runs the conversion on
a copy of the config so we can detect the situation before
``get_peft_model`` blows up.
"""
if getattr(lora_config, "target_parameters", None):
return False
try:
from peft.utils.transformers_weight_conversion import (
convert_peft_config_for_transformers,
get_model_conversion_mapping,
)
except ImportError:
return False
import copy
probe_cfg = copy.deepcopy(lora_config)
try:
convert_peft_config_for_transformers(
probe_cfg,
model=model,
conversions=get_model_conversion_mapping(model),
)
except Exception: # pylint: disable=broad-except
return False
return bool(getattr(probe_cfg, "target_parameters", None))
def _patch_peft_param_wrapper_dropout():
"""Let PEFT's ``ParamWrapper`` silently accept ``lora_dropout != 0``.
``ParamWrapper`` wraps 3D expert ``nn.Parameter`` tensors and rejects
non-zero dropout because dropout can't be factored out of
``lora_B(lora_A(dropout(x)))`` when the inner op is an expert-indexed
matmul. For mixed configs (attention + MoE experts) this is too
aggressive — the non-expert ``Linear`` LoRA layers *can* apply dropout
and that's usually what the user intended. We pass a copy of the
``LoraConfig`` with ``lora_dropout=0`` only to ``ParamWrapper.__init__``
so it builds with ``nn.Identity`` for its internal dropout slot while
every other layer type still receives the real dropout value.
"""
from peft.tuners.lora.layer import ParamWrapper
if getattr(ParamWrapper, "_axolotl_dropout_patched", False):
return
_orig_init = ParamWrapper.__init__
def _patched_init(
self,
base_layer,
adapter_name,
parameter_name,
config,
*args,
**kwargs,
):
if getattr(config, "lora_dropout", 0):
import copy as _copy
patched_config = _copy.copy(config)
patched_config.lora_dropout = 0.0
return _orig_init(
self,
base_layer,
adapter_name,
parameter_name,
patched_config,
*args,
**kwargs,
)
return _orig_init(
self,
base_layer,
adapter_name,
parameter_name,
config,
*args,
**kwargs,
)
ParamWrapper.__init__ = _patched_init
ParamWrapper._axolotl_dropout_patched = True
def load_lora(
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
@@ -124,6 +277,7 @@ def load_lora(
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,
@@ -132,6 +286,20 @@ def load_lora(
if config_only:
return None, lora_config
if getattr(
lora_config, "lora_dropout", 0
) and _peft_will_auto_convert_target_params(model, lora_config):
LOG.warning(
"lora_dropout=%s requested but PEFT will wrap this model's fused "
"MoE expert parameters with ParamWrapper, which cannot apply "
"dropout (the 3D einsum can't factor dropout out of "
"lora_B(lora_A(dropout(x)))). Dropout will still be applied to "
"non-expert LoRA layers (e.g. attention), and expert LoRA layers "
"will use nn.Identity for the dropout slot.",
lora_config.lora_dropout,
)
_patch_peft_param_wrapper_dropout()
rank = int(os.environ.get("LOCAL_RANK", 0))
if (

View File

@@ -547,6 +547,16 @@ class ModelLoader:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
if self.cfg.model_quantization_config == "FineGrainedFP8Config":
from transformers import FineGrainedFP8Config
fp8_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
fp8_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
**fp8_kwargs
)
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
LOG.warning(
@@ -624,7 +634,14 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:
if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers;
# global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
# Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True
elif self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"

View File

@@ -156,6 +156,7 @@ class PatchManager:
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
self._apply_gemma_hybrid_attention(model)
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -165,6 +166,72 @@ class PatchManager:
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
Gemma 4 has global (full_attention) layers with head_dim=512
which exceeds flash attention's supported size. This patch loads the model
with flash_attention_2 for the sliding window layers (head_dim=256), then
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
"""
if not self.cfg.gemma4_hybrid_attn_impl:
return
import copy
# Navigate to the module that has 'layers' - varies by model structure:
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
layers = None
config_source = None
for candidate in [model, getattr(model, "model", None)]:
if candidate is None:
continue
# Check direct layers
if hasattr(candidate, "layers"):
layers = candidate.layers
config_source = candidate
break
# Check language_model.layers (multimodal wrapper)
lang_model = getattr(candidate, "language_model", None)
if lang_model is not None and hasattr(lang_model, "layers"):
layers = lang_model.layers
config_source = lang_model
break
if layers is None:
LOG.warning(
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
)
return
config = getattr(config_source, "config", self.model_config)
layer_types = getattr(config, "layer_types", None)
if layer_types is None:
LOG.warning(
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
"This feature requires a model with mixed sliding/global attention layers."
)
return
patched_count = 0
for layer_idx, layer in enumerate(layers):
if layer_types[layer_idx] != "sliding_attention":
# Global / full_attention layer - use SDPA instead of FA2
attn_module = getattr(layer, "self_attn", None)
if attn_module is not None and hasattr(attn_module, "config"):
sdpa_config = copy.copy(attn_module.config)
sdpa_config._attn_implementation = "sdpa"
attn_module.config = sdpa_config
patched_count += 1
LOG.info(
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
"(remaining %d sliding layers use flash_attention_2)",
patched_count,
len(layers) - patched_count,
)
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing:
@@ -324,6 +391,13 @@ class PatchManager:
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
patch_gemma4_fused_attn()
@staticmethod
def _fix_nemotron_h_conversion_mapping():
"""Remove the spurious embedding→embeddings WeightRenaming from the

View File

@@ -221,14 +221,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
@@ -303,6 +295,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
{"additional_special_tokens": additional_special_tokens}
)
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
if is_main_process():
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")

View File

@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
sharded_meta_param.placements,
src_data_rank=0,
)
# Clone the local shard to allow full_tensor to be freed.
if (
sharded_param._local_tensor.untyped_storage().size()
> sharded_param._local_tensor.nelement()
* sharded_param._local_tensor.element_size()
):
sharded_param = sharded_param.clone()
else:
# Non-sharded parameters
if _accelerator.is_main_process:

View File

@@ -86,12 +86,19 @@ def patch_flash_attn_4(model_config=None):
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
try:
# flash-attn-4>=4.0.0b7
from flash_attn.cute import flash_attn_with_kvcache
except ImportError:
flash_attn_with_kvcache = None
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
flash_attn_with_kvcache,
fa_utils._pad_input,
fa_utils._unpad_input,
)

View File

@@ -16,6 +16,7 @@ from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qk,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
@@ -111,6 +112,47 @@ QKV_PATCHES = [
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
# Gemma4 (transformers >= 5.6): shared_kv_states parameter replaces
# past_key_values.shared_layers, and v_norm added after k_norm.
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer.
# We cannot simply reuse the cached state if we have a Cache, as sliding layers will not remember the full states in their Cache
# once we are past the sliding window - so we always use `shared_kv_states` instead, even when past_key_values is not None
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = key_states.view(hidden_shape)
value_states = value_states.view(hidden_shape) if self.v_proj is not None else key_states
""".lstrip("\n"),
),
]
@@ -483,18 +525,24 @@ def apply_lora_kernel_patches(
if cfg.lora_qkv_kernel:
# Query, key, value patching
# Filter out None projections (e.g. Gemma4 v_proj when attention_k_eq_v=True)
proj_names = ["q_proj", "k_proj", "v_proj"]
layer_modules = [
getattr(self_attn, name)
for name in proj_names
if getattr(self_attn, name, None) is not None
]
has_v_proj = getattr(self_attn, "v_proj", None) is not None
proj_names = (
["q_proj", "k_proj", "v_proj"]
if has_v_proj
else ["q_proj", "k_proj"]
)
layer_modules = [getattr(self_attn, name) for name in proj_names]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
if has_v_proj:
self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, self_attn
)
else:
self_attn.apply_qkv = types.MethodType(apply_lora_qk, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters"

View File

@@ -0,0 +1,147 @@
"""
Gemma 4 fused attention monkeypatch.
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
kernels, eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb.
Usage:
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
patch_gemma4_fused_attn()
"""
import logging
from typing import Callable
import torch
logger = logging.getLogger(__name__)
def _make_fused_forward(original_forward):
"""Create a patched forward that uses fused RMSNorm+RoPE kernels."""
from axolotl.kernels.gemma4_fused_rope import (
fused_rms_norm_noscale,
fused_rms_norm_rope,
)
def fused_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values=None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gemma4.modeling_gemma4 import (
eager_attention_forward,
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
eps = self.config.rms_norm_eps
cos, sin = position_embeddings
# ---- Projections ----
# Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
has_lora_qkv = hasattr(self, "apply_qkv")
if has_lora_qkv:
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
else:
query_states = self.q_proj(hidden_states).view(hidden_shape)
# ---- Q path: fused q_norm + RoPE ----
query_states = fused_rms_norm_rope(
query_states,
self.q_norm.weight,
cos,
sin,
eps=eps,
)
query_states = query_states.transpose(1, 2)
# ---- K/V path ----
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
if has_lora_qkv:
# apply_qkv already computed k/v projections
key_states = key_states.view(hidden_shape)
value_states = (
value_states.view(hidden_shape)
if self.v_proj is not None
else key_states
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = (
self.v_proj(hidden_states).view(hidden_shape)
if self.v_proj is not None
else key_states
)
# Fused k_norm + RoPE
key_states = fused_rms_norm_rope(
key_states,
self.k_norm.weight,
cos,
sin,
eps=eps,
)
key_states = key_states.transpose(1, 2)
# Fused v_norm (no scale, no RoPE)
value_states = fused_rms_norm_noscale(value_states, eps=eps)
value_states = value_states.transpose(1, 2)
if past_key_values is not None and not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
if self.store_full_length_kv:
shared_kv_states[self.layer_idx] = key_states, value_states
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=self.attention_dropout if self.training else 0.0,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
return fused_forward
def patch_gemma4_fused_attn():
"""
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
original_forward = Gemma4TextAttention.forward
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
logger.info(
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
)

View File
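The patch above swaps Gemma4TextAttention.forward at the class level, so it must run before any forward pass. A minimal sketch of applying and sanity-checking it (editorial, not part of the diff; assumes a transformers build that ships the Gemma4 classes referenced in this patch, and the assert is purely illustrative):

from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention

from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn

original_forward = Gemma4TextAttention.forward
patch_gemma4_fused_attn()
# the override is class-level, so every attention layer that resolves forward
# through Gemma4TextAttention now routes through the fused RMSNorm+RoPE path
assert Gemma4TextAttention.forward is not original_forward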

@@ -315,6 +315,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._validate_eot_and_eos_tokens()
# Pre-cache EOT token IDs to avoid re-encoding on every call
self._eot_token_ids = set()
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) == 1:
self._eot_token_ids.add(token_ids[0])
def _validate_eot_and_eos_tokens(self):
"""
- Validates that EOT tokens (or eos_token) are in the chat_template
@@ -471,6 +478,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")
reasoning_train_detail = turn.get("reasoning_training_detail")
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
@@ -479,8 +487,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
elif train_detail is not None or reasoning_train_detail is not None:
should_train = bool(train_detail) or bool(reasoning_train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train
@@ -500,15 +508,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
continue
thinking_key = self.prompter.template_thinking_key
has_reasoning = thinking_key and turn.get(thinking_key) is not None
has_any_detail = train_detail or reasoning_train_detail
# When train_detail is present and the turn has reasoning_content,
# use content_only=True so find_turn returns content-only boundaries
# (excluding reasoning_content + template separator tokens).
use_content_only = bool(has_any_detail and has_reasoning)
turn_start_idx, turn_end_idx = self.find_turn(
turns=turns, turn_idx=index, tools=tools
turns=turns,
turn_idx=index,
tools=tools,
content_only=use_content_only,
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
# Block multi-content for now
if not isinstance(content, str):
raise ValueError(
"`train_detail` is not supported when `content` is not a string."
@@ -526,7 +545,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
elif not reasoning_train_detail:
# No per-part detail on either field — train the whole span
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
@@ -534,6 +554,32 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
# Handle reasoning_content training_detail separately
if should_train and reasoning_train_detail and has_reasoning:
reasoning_text = turn[thinking_key]
if not isinstance(reasoning_text, str):
raise ValueError(
"`reasoning_training_detail` is not supported when reasoning_content is not a string."
)
reasoning_start, reasoning_end = self.find_turn(
turns=turns,
turn_idx=index,
tools=tools,
reasoning_only=True,
)
if reasoning_start != -1 and reasoning_end != -1:
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
reasoning_text, reasoning_train_detail
)
LOG.debug(f"Reasoning token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and reasoning_start + i < len(
input_ids
):
labels[reasoning_start + i] = input_ids[reasoning_start + i]
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle special tokens (EOT and EOS)
@@ -593,28 +639,31 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# Get token IDs for all EOT tokens
eot_token_ids = []
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) != 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
)
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
# Search for any of the EOT token IDs
# Use pre-cached EOT token IDs (computed once in __init__)
for i in range(start_idx, len(input_ids)):
if input_ids[i] in eot_token_ids:
if input_ids[i] in self._eot_token_ids:
return i
return -1
def find_turn(
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
self,
turns: list[dict],
turn_idx: int,
tools: list[dict] | None = None,
content_only: bool = False,
reasoning_only: bool = False,
):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
content_only: If True and the turn has reasoning_content (template_thinking_key),
preserve reasoning_content in the dummy turn so the diff only captures the
content field boundaries. This is needed for correct training_detail alignment
when reasoning_content is present.
reasoning_only: If True, preserve content in the dummy turn and replace
reasoning_content with a dummy, so the diff only captures the
reasoning_content field boundaries.
"""
if turn_idx >= len(turns):
@@ -628,10 +677,26 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
):
return -1, -1
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
thinking_key = self.prompter.template_thinking_key
if reasoning_only:
# Keep content as-is, replace reasoning with dummy
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": turns[turn_idx].get("content", ""),
}
if thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = "[[dummy_reasoning]]"
else:
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}
# When content_only is True, copy reasoning_content to the dummy turn so
# the diff only captures the content field (not reasoning + separator).
if content_only and thinking_key and thinking_key in turns[turn_idx]:
empty_turn[thinking_key] = turns[turn_idx][thinking_key]
# Create conversation versions
turns_with_empty = turns[:turn_idx] + [empty_turn]
@@ -697,6 +762,94 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return start_idx, end_idx
@staticmethod
def _convert_content_parts(
content,
) -> tuple[str, list[dict] | None] | None:
"""Convert list content to concatenated string + optional training_detail.
When content is a list of dicts (content parts), each part can specify:
- ``text``, ``content``, or ``value``: the text string
- ``train`` (bool) or ``weight`` (0/1): per-part training flag
Returns ``(concatenated_text, training_details_or_None)`` if content was
a list, or ``None`` if content was not a list (no conversion needed).
.. note::
**Whitespace at part boundaries matters.** BPE tokenizers prepend
spaces to word tokens (e.g. ``" answer"`` is one token). Always
split BEFORE spaces::
GOOD: ["Let me think...", " The answer is 4."]
BAD: ["Let me think... ", "The answer is 4."]
Tokens that straddle a boundary are conservatively masked.
Newlines typically merge with preceding punctuation (``":\\n"`` is
one token), so keep newlines with the preceding part.
"""
if not isinstance(content, list):
return None
text_parts: list[str] = []
training_details: list[dict] = []
has_explicit_training = False
offset = 0
for part in content:
if isinstance(part, dict):
# Extract text (HF uses "text", also supports "content"/"value")
text = (
part.get("text") or part.get("content") or part.get("value") or ""
)
text_parts.append(text)
# Check for per-part training flags
part_train = part.get("train")
part_weight = part.get("weight")
if part_train is not None or part_weight is not None:
has_explicit_training = True
train = (
part_train
if part_train is not None
else (part_weight not in (0, 0.0))
)
else:
train = True # default trainable, gated by turn-level should_train
if text:
training_details.append(
{
"begin_offset": offset,
"end_offset": offset + len(text) - 1,
"train": train,
}
)
offset += len(text)
# Warn about trailing whitespace at boundaries between parts with
# different training flags — this almost always causes token straddling
if has_explicit_training and len(training_details) > 1:
for i in range(len(training_details) - 1):
cur = training_details[i]
nxt = training_details[i + 1]
if cur["train"] != nxt["train"]:
boundary_text = text_parts[i]
if boundary_text and boundary_text[-1] in (" ", "\t"):
LOG.warning(
"Content part %d ends with whitespace at a train/mask boundary. "
"BPE tokenizers typically prepend spaces to word tokens, so "
"the space will merge with the next part's first word and the "
"resulting token will be MASKED (not trained). Move the "
"whitespace to the start of the next content part instead. "
"Part text: %r",
i,
boundary_text[-20:],
)
concatenated = "".join(text_parts)
details = training_details if has_explicit_training else None
return concatenated, details
def get_conversation_thread(self, prompt):
turns = []
@@ -723,6 +876,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if training_detail is not None:
turn["training_detail"] = training_detail
# Convert list content/reasoning_content to string + auto-generated
# training_detail. See _convert_content_parts for whitespace guidance.
content_result = self._convert_content_parts(turn.get("content"))
if content_result is not None:
turn["content"] = content_result[0]
if content_result[1] is not None:
turn["training_detail"] = content_result[1]
# Also convert reasoning_content (template_thinking_key) if it's a list
thinking_key = self.prompter.template_thinking_key
if thinking_key and thinking_key in turn:
reasoning_result = self._convert_content_parts(turn[thinking_key])
if reasoning_result is not None:
turn[thinking_key] = reasoning_result[0]
if reasoning_result[1] is not None:
turn["reasoning_training_detail"] = reasoning_result[1]
turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":

View File
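The offset bookkeeping in _convert_content_parts above is easiest to see with a tiny worked example. This sketch (editorial, not part of the diff) mirrors the helper's loop over character offsets; the part texts follow the GOOD split from the whitespace note, with the boundary space kept at the start of the trained part:

parts = [
    {"type": "text", "text": "Let me think...", "train": False},
    {"type": "text", "text": " The answer is 4.", "train": True},
]
offset, details = 0, []
for part in parts:
    text = part["text"]
    details.append(
        {"begin_offset": offset, "end_offset": offset + len(text) - 1, "train": part["train"]}
    )
    offset += len(text)
assert "".join(p["text"] for p in parts) == "Let me think... The answer is 4."
assert details == [
    {"begin_offset": 0, "end_offset": 14, "train": False},
    {"begin_offset": 15, "end_offset": 31, "train": True},
]
# get_offsets_for_train_detail later maps these character spans onto token
# indices; tokens that straddle the 14/15 boundary are conservatively masked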

@@ -36,7 +36,7 @@ from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.freeze import freeze_layers_except, freeze_mm_modules
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
@@ -114,6 +114,10 @@ def setup_model_and_tokenizer(
):
model.enable_input_require_grads()
# Freeze multimodal modules for text-only training of multimodal models
if cfg.freeze_mm_modules:
freeze_mm_modules(model)
return model, tokenizer, peft_config, processor
@@ -225,6 +229,28 @@ def execute_training(
PLUGIN_MANAGER.post_train(cfg, trainer.model)
def _rename_fsdp_merged_to_adapter(merged_dir: Path):
"""Rename model*.safetensors files to adapter_model* in place.
Also rewrites the index JSON weight_map if sharded output was produced.
"""
for file in sorted(merged_dir.iterdir()):
if file.name.startswith("model") and ".safetensors" in file.name:
file.rename(merged_dir / file.name.replace("model", "adapter_model", 1))
index = merged_dir / "adapter_model.safetensors.index.json"
if index.exists():
data = json.loads(index.read_text(encoding="utf-8"))
if "weight_map" in data:
data["weight_map"] = {
k: v.replace("model", "adapter_model", 1)
for k, v in data["weight_map"].items()
}
index.write_text(
json.dumps(data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
)
def save_trained_model(
cfg: DictDefault,
trainer: Any,
@@ -294,12 +320,17 @@ def save_trained_model(
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
# FSDP checkpoints for PEFT only contain adapter weights;
# rename model* → adapter_model* so it loads correctly.
is_peft = cfg.adapter and not cfg.relora
if is_peft:
_rename_fsdp_merged_to_adapter(Path(merged_path))
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
dest = Path(cfg.output_dir) / merged_file.name
if dest.exists():
dest.unlink()
shutil.move(str(merged_file), dest)
shutil.rmtree(merged_path)
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
if trainer.accelerator.is_main_process:

View File
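The rename helper above only touches file and shard names. A small sketch of the mapping it produces (file names are illustrative; the replacement is the same one-shot str.replace the helper uses):

names = [
    "model.safetensors",
    "model-00001-of-00002.safetensors",
    "model.safetensors.index.json",
]
renamed = [n.replace("model", "adapter_model", 1) for n in names]
assert renamed == [
    "adapter_model.safetensors",
    "adapter_model-00001-of-00002.safetensors",
    "adapter_model.safetensors.index.json",
]
# the weight_map values inside the index JSON get the same one-shot replacement,
# so shard references keep pointing at the renamed adapter_model-* files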

@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
return control
class SkipEvalOnResumeCallback(TrainerCallback):
"""Skip the redundant evaluation that fires when resuming from a checkpoint
whose step aligns with ``eval_steps``.
When HuggingFace Trainer resumes, it restores ``global_step`` from the
checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
step. Because the evaluation was already performed during the original
run, repeating it wastes time and pollutes metric logs.
This callback records the ``global_step`` at the start of training (i.e.
the checkpoint step when resuming, or 0 for a fresh run) and suppresses
any evaluation request on that exact step.
"""
def __init__(self):
super().__init__()
self._resume_step: int | None = None
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
# ``global_step`` is already restored from the checkpoint at this
# point. For a fresh run it will be 0, so the guard below becomes a
# no-op.
self._resume_step = state.global_step
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
if (
self._resume_step
and state.global_step <= self._resume_step
and control.should_evaluate
):
LOG.info(
"Skipping evaluation at step %d (already completed before resume)",
state.global_step,
)
control.should_evaluate = False
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File
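A minimal sketch of how SkipEvalOnResumeCallback would typically be wired up (assumes the standard HF Trainer.add_callback API and that resumption is signalled by a truthy resume_from_checkpoint; the helper name is hypothetical):

from transformers import Trainer

from axolotl.utils.callbacks import SkipEvalOnResumeCallback


def maybe_add_skip_eval_on_resume(trainer: Trainer, resume_from_checkpoint) -> None:
    # on a fresh run the callback is a no-op (global_step starts at 0), so
    # registering it only when actually resuming keeps the callback list tidy
    if resume_from_checkpoint:
        trainer.add_callback(SkipEvalOnResumeCallback())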

@@ -1,7 +1,19 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- if messages[0].content is string %}
{{- messages[0].content + '\n\n' }}
{%- else %}
{%- for part in messages[0].content %}
{%- if part is mapping %}
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
{%- if system_text %}{{- system_text }}{%- endif %}
{%- elif part is string %}
{{- part }}
{%- endif %}
{%- endfor %}
{{- '\n\n' }}
{%- endif %}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
@@ -11,7 +23,20 @@
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- if messages[0].content is string %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>system\n' }}
{%- for part in messages[0].content %}
{%- if part is mapping %}
{%- set system_text = part.get('text') or part.get('content') or part.get('value') %}
{%- if system_text %}{{- system_text }}{%- endif %}
{%- elif part is string %}
{{- part }}
{%- endif %}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}

View File

@@ -268,6 +268,37 @@ def normalize_config(cfg):
):
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
# Gemma4 requires use_reentrant=False for DDP (shared per-layer norms cause
# "marked ready twice" errors with reentrant checkpointing) and
# ddp_find_unused_parameters=True (per_layer_projection LoRA params may not
# receive gradients on every step).
if cfg.model_config_type == "gemma4":
if cfg.gradient_checkpointing:
if cfg.gradient_checkpointing_kwargs is None:
cfg.gradient_checkpointing_kwargs = {}
if cfg.gradient_checkpointing_kwargs.get("use_reentrant") is not False:
LOG.warning(
"Gemma4 requires use_reentrant=False for gradient checkpointing "
"in distributed training. Setting use_reentrant=False."
)
cfg.gradient_checkpointing_kwargs["use_reentrant"] = False
if cfg.ddp and cfg.ddp_find_unused_parameters is None:
if cfg.activation_offloading is True:
# activation_offloading uses checkpoint wrappers that conflict
# with find_unused_parameters (causes "marked ready twice").
# Use freeze_mm_modules instead to eliminate unused params.
LOG.info(
"Gemma4 + DDP + activation_offloading: skipping "
"ddp_find_unused_parameters (use freeze_mm_modules to "
"handle unused vision/audio params)."
)
else:
LOG.warning(
"Gemma4 requires ddp_find_unused_parameters=True for DDP. "
"Auto-enabling."
)
cfg.ddp_find_unused_parameters = True
log_gpu_memory_usage(LOG, "baseline", cfg.device)

View File
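Read as user-facing configuration, the Gemma4 normalization above settles on roughly these values for a DDP run (keys are the cfg fields used in the hunk; plain dict form is shown for illustration only):

gemma4_ddp_cfg = {
    "model_config_type": "gemma4",
    "gradient_checkpointing": True,
    # forced to non-reentrant: shared per-layer norms break reentrant checkpointing
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    # auto-enabled when left unset (and activation_offloading is not in use)
    "ddp_find_unused_parameters": True,
}
# with activation_offloading enabled, ddp_find_unused_parameters is left alone
# and freeze_mm_modules is the suggested way to avoid unused-parameter errors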

@@ -180,6 +180,119 @@ def _drop_long_sequences(
raise ValueError("Unknown RL type")
def _raise_on_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
"""Check sequence length and raise ValueError if exceeded.
Used as a filter function for ``excess_length_strategy: raise``.
Args:
sample: Dataset sample to check.
rl: Reinforcement learning type.
tokenizer: Tokenizer for length calculation.
sequence_len: Maximum allowed sequence length.
Returns:
Always True (raises instead of returning False).
Raises:
ValueError: If any sample exceeds the configured sequence length.
"""
is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
if not is_valid:
raise ValueError(
f"Sample exceeds configured sequence_len ({sequence_len}). "
"Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
"to handle long sequences automatically."
)
return True
def _truncate_long_sequences_rl(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> dict[str, Any]:
"""Truncate RL samples that exceed maximum sequence length.
For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
responses to fit within ``sequence_len`` when combined with the prompt.
For KTO, truncates the completion similarly.
GRPO/GDPO/EBFT samples are returned unchanged.
Samples where the prompt alone exceeds ``sequence_len`` cannot be
meaningfully truncated and are returned unchanged. The caller should
follow up with a drop filter to remove them.
Args:
sample: Dataset sample to potentially truncate.
rl: Reinforcement learning type.
tokenizer: Tokenizer for encoding/decoding.
sequence_len: Maximum allowed sequence length.
Returns:
The sample with text fields truncated to fit within sequence_len.
"""
# Fast path: if sample already fits, return unchanged (avoids decode overhead)
if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
return sample
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
raise ValueError(
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
)
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
"input_ids"
]
max_response_len = sequence_len - len(prompt_ids)
if max_response_len <= 0:
# Prompt alone exceeds limit; cannot meaningfully truncate.
# Return it unchanged; the follow-up drop filter will remove it.
return sample
updates: dict[str, Any] = {}
if len(chosen_ids) > max_response_len:
updates["chosen"] = tokenizer.decode(
chosen_ids[:max_response_len], skip_special_tokens=False
)
if len(rejected_ids) > max_response_len:
updates["rejected"] = tokenizer.decode(
rejected_ids[:max_response_len], skip_special_tokens=False
)
if updates:
sample = {**sample, **updates}
elif rl is RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
"input_ids"
]
max_completion_len = sequence_len - len(prompt_ids)
if max_completion_len <= 0:
return sample
if len(completion_ids) > max_completion_len:
sample = {
**sample,
"completion": tokenizer.decode(
completion_ids[:max_completion_len], skip_special_tokens=False
),
}
# GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
return sample
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
split_datasets[i] = dataset
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
if excess_length_strategy == "truncate":
truncate_fn = partial(
_truncate_long_sequences_rl,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].map(
truncate_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# Drop samples that could not be truncated (e.g. prompt
# alone exceeds sequence_len)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Un-truncatable Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} samples from dataset index {i} "
f"that could not be truncated to fit sequence_len "
f"(prompt alone exceeds limit)"
)
elif excess_length_strategy == "raise":
raise_fn = partial(
_raise_on_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
raise_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Checking Sequence Lengths",
)
else: # "drop" (default)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)

View File
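The truncate branch above preserves the full prompt and cuts only the responses. A worked sketch of that arithmetic for a DPO-style sample, using a toy whitespace tokenizer (same spirit as the _FakeTokenizer in the new tests/utils/data/test_rl.py; a real tokenizer only changes how text maps to ids):

def toy_tokenize(text: str) -> list[str]:
    return text.split()

sample = {
    "prompt": "p0 p1 p2 p3 p4",            # 5 tokens
    "chosen": "c0 c1 c2 c3 c4 c5 c6 c7",   # 8 tokens -> over budget
    "rejected": "r0 r1",                   # 2 tokens -> already fits
}
sequence_len = 10
max_response_len = sequence_len - len(toy_tokenize(sample["prompt"]))
assert max_response_len == 5
truncated_chosen = " ".join(toy_tokenize(sample["chosen"])[:max_response_len])
assert truncated_chosen == "c0 c1 c2 c3 c4"
# "rejected" is shorter than max_response_len, so it stays as-is; if the prompt
# alone used up sequence_len the sample would be returned unchanged and later
# removed by the "Dropping Un-truncatable Sequences" filter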

@@ -10,6 +10,44 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module name prefixes that belong to vision/audio/multimodal encoders rather
# than the language backbone. A parameter is frozen when any component of its
# ``named_parameters()`` path matches one of these (e.g. "model.vision_tower." -> "vision_tower").
_MM_MODULE_PREFIXES = (
"vision_tower",
"vision_model",
"vision_encoder",
"embed_vision",
"multi_modal_projector",
"visual",
"audio_tower",
"audio_model",
"embed_audio",
)
def freeze_mm_modules(model):
"""Freeze all vision/audio/multimodal-projector parameters.
Iterates over ``model.named_parameters()`` and sets ``requires_grad = False``
for any parameter whose name contains a known vision/audio module prefix.
This is useful when fine-tuning only the language backbone of a multimodal
model and avoids the need for ``ddp_find_unused_parameters=True``.
"""
frozen_count = 0
for name, param in model.named_parameters():
# Check if any path component matches a vision/audio prefix
parts = name.split(".")
if any(part in _MM_MODULE_PREFIXES for part in parts):
if param.requires_grad:
param.requires_grad = False
frozen_count += 1
if is_main_process():
LOG.debug(f"freeze_mm_modules: froze {name}")
if is_main_process():
LOG.info(f"freeze_mm_modules: froze {frozen_count} vision/audio parameters")
def freeze_layers_except(model, regex_patterns):
"""

View File
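The matching rule in freeze_mm_modules is a plain membership test on path components. A small sketch (parameter names are illustrative; the test mirrors the helper's loop):

_MM_PREFIXES = ("vision_tower", "multi_modal_projector", "audio_tower")


def is_mm_param(name: str) -> bool:
    return any(part in _MM_PREFIXES for part in name.split("."))


assert is_mm_param("model.vision_tower.encoder.layers.0.self_attn.q_proj.weight")
assert is_mm_param("multi_modal_projector.linear.weight")
assert not is_mm_param("model.language_model.layers.0.mlp.gate_proj.weight")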

@@ -578,6 +578,17 @@ class AxolotlInputConfig(
},
)
freeze_mm_modules: bool | None = Field(
default=None,
json_schema_extra={
"description": "Freeze multimodal encoder parameters (vision, audio, etc.) for "
"text-only training of multimodal models. When True, parameters belonging to "
"vision towers, audio towers, multimodal projectors, and similar non-language "
"modules are frozen (requires_grad=False). This allows DDP training without "
"ddp_find_unused_parameters=True."
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
@@ -766,6 +777,15 @@ class AxolotlInputConfig(
},
)
gemma4_hybrid_attn_impl: bool | None = Field(
default=None,
json_schema_extra={
"description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
"and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
"exceeds flash attention's supported size."
},
)
experts_implementation: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -87,9 +87,11 @@ class ModelInputConfig(BaseModel):
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
)
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
model_quantization_config: Literal["Mxfp4Config", "FineGrainedFP8Config"] | None = (
Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
)
)
model_quantization_config_kwargs: dict[str, Any] | None = Field(
default=None,

View File

@@ -422,22 +422,22 @@ class TestPeftLoRAWeightExtraction:
)
# gate_up_proj [E, 2*inter, hidden]
# peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
# peft: in_features=hidden (last dim), out_features=2*inter (middle dim)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
].shape == (E * r, 2 * config.intermediate_size)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
].shape == (config.hidden_size, E * r)
# down_proj [E, hidden, inter]
# peft: in_features=hidden (dim 1), out_features=inter (dim 2)
assert trainable[
"base_model.model.moe.experts.lora_A.default.weight"
].shape == (E * r, config.hidden_size)
assert trainable[
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
].shape == (2 * config.intermediate_size, E * r)
# down_proj [E, hidden, inter]
# peft: in_features=inter (last dim), out_features=hidden (middle dim)
assert trainable[
"base_model.model.moe.experts.lora_A.default.weight"
].shape == (E * r, config.intermediate_size)
assert trainable[
"base_model.model.moe.experts.lora_B.default.weight"
].shape == (config.intermediate_size, E * r)
].shape == (config.hidden_size, E * r)
@requires_cuda
def test_peft_forward_runs(self):
@@ -489,26 +489,28 @@ class TestPeftLoRAWeightExtraction:
assert down_lora is not None, "down_proj LoRA not detected"
# Check shapes (after peft->scattermoe conversion with A<->B swap)
# gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
# After swap: smoe_A [E*r, 2*inter], smoe_B [hidden, E*r]
E, r = config.num_experts, 4
gup_A, gup_B, gup_s = gup_lora
assert gup_A.shape == (E * r, config.hidden_size), (
f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
assert gup_A.shape == (E * r, 2 * config.intermediate_size), (
f"gate_up_proj smoe_A: expected [r*E, 2*inter]={(E * r, 2 * config.intermediate_size)}, "
f"got {gup_A.shape}"
)
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
assert gup_B.shape == (config.hidden_size, E * r), (
f"gate_up_proj smoe_B: expected [hidden, r*E]="
f"{(config.hidden_size, E * r)}, got {gup_B.shape}"
)
# down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
# After swap: smoe_A [E*r, hidden], smoe_B [inter, E*r]
down_A, down_B, down_s = down_lora
assert down_A.shape == (E * r, config.intermediate_size), (
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
assert down_A.shape == (E * r, config.hidden_size), (
f"down_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
f"got {down_A.shape}"
)
assert down_B.shape == (config.hidden_size, E * r), (
f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
assert down_B.shape == (config.intermediate_size, E * r), (
f"down_proj smoe_B: expected [inter, r*E]={(config.intermediate_size, E * r)}, "
f"got {down_B.shape}"
)

View File

@@ -0,0 +1,226 @@
"""
Correctness tests for the fused RMSNorm+RoPE Triton kernel.
Tests forward and backward against the reference Gemma4 implementation
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
(head_dim=256) and global attention (head_dim=512) layer configurations.
"""
import pytest
import torch
torch.manual_seed(42)
# Skip entire module if no CUDA
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def _reference_norm_rope(x, weight, cos, sin, eps):
"""Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
D = x.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
norm.weight.data.copy_(weight)
normed = norm(x)
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
def _reference_norm_noscale(x, eps):
"""Reference: Gemma4RMSNorm with_scale=False."""
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
D = x.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps, with_scale=False).to(x.device, x.dtype)
return norm(x)
@pytest.fixture(
params=[
(2, 64, 32, 256), # sliding window layer shape
(2, 64, 4, 512), # global attention layer shape
(1, 128, 16, 256), # different batch/seq
(1, 1, 1, 8), # minimal size
],
ids=["sliding_256", "global_512", "varied", "minimal"],
)
def shapes(request):
return request.param
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def dtype(request):
return request.param
class TestFusedRMSNormRoPEForward:
"""Forward pass correctness."""
def test_matches_reference(self, shapes, dtype):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = shapes
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
weight = torch.randn(D, device="cuda", dtype=dtype)
cos = torch.randn(B, S, D, device="cuda", dtype=dtype)
sin = torch.randn(B, S, D, device="cuda", dtype=dtype)
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"
def test_output_shape(self, shapes):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = shapes
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
assert y.shape == x.shape
assert y.dtype == x.dtype
class TestFusedRMSNormRoPEBackward:
"""Backward pass correctness via gradient comparison."""
@pytest.mark.parametrize(
"B,S,H,D",
[(2, 64, 32, 256), (2, 64, 4, 512)],
ids=["sliding_256", "global_512"],
)
def test_x_grad_matches_reference(self, B, S, H, D):
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# Reference backward
x_ref = torch.randn(
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight.data.copy_(weight_init)
y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
y_ref.sum().backward()
# Fused backward
x_fused = x_ref.data.clone().requires_grad_(True)
w_fused = weight_init.clone().requires_grad_(True)
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
y_fused.sum().backward()
cos_sim_x = torch.nn.functional.cosine_similarity(
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
)
assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"
@pytest.mark.parametrize(
"B,S,H,D",
[(2, 64, 32, 256), (2, 64, 4, 512)],
ids=["sliding_256", "global_512"],
)
def test_weight_grad_matches_reference(self, B, S, H, D):
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# Reference
x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
apply_rotary_pos_emb(
norm_ref(x_ref), cos, sin, unsqueeze_dim=2
).sum().backward()
# Fused
w_fused = weight_init.clone().requires_grad_(True)
fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()
cos_sim_w = torch.nn.functional.cosine_similarity(
w_fused.grad.flatten().float(),
norm_ref.weight.grad.flatten().float(),
dim=0,
)
assert cos_sim_w > 0.995, (
f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
)
def test_grad_flows(self):
"""Verify gradients are non-zero and finite."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 1, 16, 4, 64
x = torch.randn(
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
y.sum().backward()
assert x.grad is not None, "x.grad is None"
assert w.grad is not None, "w.grad is None"
assert x.grad.isfinite().all(), "x.grad has non-finite values"
assert w.grad.isfinite().all(), "w.grad has non-finite values"
assert x.grad.abs().sum() > 0, "x.grad is all zeros"
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
class TestFusedRMSNormNoScale:
"""Tests for v_norm (RMSNorm without learnable scale)."""
def test_forward_matches_reference(self, shapes, dtype):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
B, S, H, D = shapes
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
y_ref = _reference_norm_noscale(x.clone(), eps)
y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"
def test_backward_flows(self):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
x = torch.randn(
1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
y = fused_rms_norm_noscale(x, eps=1e-6)
y.sum().backward()
assert x.grad is not None
assert x.grad.isfinite().all()
assert x.grad.abs().sum() > 0

View File

@@ -916,6 +916,235 @@ class TestChatTemplateConfigurations:
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
@enable_hf_offline
def test_content_parts_training(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing with content as list of parts with per-part training")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Dataset where assistant content is a list of parts with per-part training
conversation = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are an AI assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is 2+2?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think...", "train": False},
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find the assistant turn (last turn)
assistant_turn_idx = len(turns) - 1
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=assistant_turn_idx
)
assert start_idx != -1 and end_idx != -1, (
"Could not find assistant turn boundaries"
)
decoded = tokenizer.decode(input_ids[start_idx:end_idx])
LOG.debug(f"Assistant turn decoded: {decoded}")
# Tokenize each part separately to find their boundaries
part1_text = "Let me think..."
part2_text = "The answer is 4."
# Verify the concatenated content is in the decoded output
assert part1_text in decoded, (
f"Part 1 '{part1_text}' not found in decoded: {decoded}"
)
assert part2_text in decoded, (
f"Part 2 '{part2_text}' not found in decoded: {decoded}"
)
# Verify that part1 tokens (train=False) are masked
# and part2 tokens (train=True) are labeled
turn_labels = labels[start_idx:end_idx]
# Find where part2 starts in the token sequence
part1_tokens = tokenizer(part1_text, add_special_tokens=False)["input_ids"]
part2_tokens = tokenizer(part2_text, add_special_tokens=False)["input_ids"]
# The first part should be masked (all IGNORE_TOKEN_ID)
# Due to token boundary alignment, check that at least the interior tokens
# of part1 are masked
assert any(label == IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected some masked labels for train=False part, but got {turn_labels}"
)
# The second part should be trained (not IGNORE_TOKEN_ID)
assert any(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Expected some trained labels for train=True part, but got {turn_labels}"
)
# More precise check: first N tokens should be masked, last M tokens should be trained
# where N ~ len(part1_tokens) and M ~ len(part2_tokens)
# Allow for token-merging effects at the part boundary
num_masked = sum(1 for label in turn_labels if label == IGNORE_TOKEN_ID)
num_trained = sum(1 for label in turn_labels if label != IGNORE_TOKEN_ID)
LOG.debug(f"Turn labels: {turn_labels}")
LOG.debug(f"Masked tokens: {num_masked}, Trained tokens: {num_trained}")
LOG.debug(
f"Part1 tokens: {len(part1_tokens)}, Part2 tokens: {len(part2_tokens)}"
)
# The number of masked tokens should be roughly the size of part1
# and the number of trained tokens should be roughly the size of part2
assert num_masked > 0, "Expected masked tokens for the train=False part"
assert num_trained > 0, "Expected trained tokens for the train=True part"
@enable_hf_offline
def test_content_parts_with_weight(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing with content parts using weight field")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Dataset using weight instead of train
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Thinking step by step: ", "weight": 0},
{"type": "text", "text": "Hello! How can I help?", "weight": 1},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
labels = res["labels"]
# There should be both masked and trained labels
has_masked = any(label == IGNORE_TOKEN_ID for label in labels)
has_trained = any(label != IGNORE_TOKEN_ID for label in labels)
assert has_masked, "Expected masked tokens (weight=0 part + user turn)"
assert has_trained, "Expected trained tokens (weight=1 part)"
@enable_hf_offline
def test_content_parts_string_passthrough(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
request,
):
LOG.info("Testing that string content still works alongside list content")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# All list content in the conversation
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is 2+2?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
# Should tokenize without errors
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
def test_get_chat_template_variables(
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
):
@@ -1428,3 +1657,250 @@ class TestChatTemplateToolCalling:
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Assistant turn {i} should be unmasked"
)
class TestChatTemplateReasoningContent:
"""
Test class for reasoning_content with content parts.
"""
@enable_hf_offline
def test_reasoning_content_with_content_parts(self, qwen3_tokenizer):
"""Test that reasoning_content as string + content as list parts works correctly.
Content training_detail offsets should align with content-only boundaries."""
LOG.info("Testing reasoning_content with content parts on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# reasoning_content is a plain string, content is list with per-part training
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": "Step 1: 2+2=4",
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find the assistant turn
assistant_idx = 1
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=assistant_idx, content_only=True
)
assert start_idx != -1 and end_idx != -1, (
"Could not find assistant content boundaries"
)
# The content-only span should contain "The answer is 4." but NOT "Step 1: 2+2=4"
decoded_span = tokenizer.decode(input_ids[start_idx:end_idx])
assert "The answer is 4." in decoded_span, (
f"Content not found in span: {decoded_span}"
)
assert "Step 1" not in decoded_span, (
f"Reasoning should not be in content-only span: {decoded_span}"
)
# Verify that content tokens are trained
content_labels = labels[start_idx:end_idx]
assert any(label != IGNORE_TOKEN_ID for label in content_labels), (
f"Expected trained labels in content span, got {content_labels}"
)
@enable_hf_offline
def test_reasoning_content_per_part_masking(self, qwen3_tokenizer):
"""Test masking incorrect reasoning while training on self-correction.
This is the core use case: mask out wrong thoughts, train on corrections."""
LOG.info("Testing reasoning_content per-part masking on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Reasoning has wrong step (masked) then self-correction (trained)
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": False},
{"type": "text", "text": " Wait no, 2+2=4.", "train": True},
],
"content": [
{"type": "text", "text": "The answer is 4.", "train": True},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
turns = strategy.get_conversation_thread(dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find reasoning boundaries
reasoning_start, reasoning_end = strategy.find_turn(
turns=turns, turn_idx=1, reasoning_only=True
)
assert reasoning_start != -1 and reasoning_end != -1, (
"Could not find reasoning boundaries"
)
decoded_reasoning = tokenizer.decode(input_ids[reasoning_start:reasoning_end])
LOG.debug(f"Reasoning span: {decoded_reasoning!r}")
assert "2+2=5" in decoded_reasoning, (
f"Wrong step not in reasoning span: {decoded_reasoning}"
)
assert "2+2=4" in decoded_reasoning, (
f"Correction not in reasoning span: {decoded_reasoning}"
)
# Verify reasoning labels have both masked and trained tokens
reasoning_labels = labels[reasoning_start:reasoning_end]
reasoning_ids = input_ids[reasoning_start:reasoning_end]
# Decode only the trained tokens — should be exactly the self-correction
trained_ids = [
tid
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
if lab != IGNORE_TOKEN_ID
]
trained_text = tokenizer.decode(trained_ids)
assert trained_text.strip() == "Wait no, 2+2=4.", (
f"Expected trained reasoning to be 'Wait no, 2+2=4.', got: {trained_text!r}"
)
# Decode only the masked tokens — should be exactly the incorrect step
masked_ids = [
tid
for tid, lab in zip(reasoning_ids, reasoning_labels, strict=True)
if lab == IGNORE_TOKEN_ID
]
masked_text = tokenizer.decode(masked_ids)
assert masked_text.strip() == "Hmm maybe 2+2=5.", (
f"Expected masked reasoning to be 'Hmm maybe 2+2=5.', got: {masked_text!r}"
)
# Find content boundaries
content_start, content_end = strategy.find_turn(
turns=turns, turn_idx=1, content_only=True
)
assert content_start != -1 and content_end != -1, (
"Could not find content boundaries"
)
# Content should be fully trained — decode trained tokens to verify
content_labels = labels[content_start:content_end]
content_ids = input_ids[content_start:content_end]
content_trained_ids = [
tid
for tid, lab in zip(content_ids, content_labels, strict=True)
if lab != IGNORE_TOKEN_ID
]
content_trained_text = tokenizer.decode(content_trained_ids)
assert "The answer is 4." in content_trained_text, (
f"Expected 'The answer is 4.' in trained content tokens, "
f"got: {content_trained_text!r}"
)
assert all(label != IGNORE_TOKEN_ID for label in content_labels), (
f"Expected all content labels trained, got {content_labels}"
)
@enable_hf_offline
def test_reasoning_content_as_list_no_training_flags(self, qwen3_tokenizer):
"""Test that reasoning_content as list without training flags still works."""
LOG.info("Testing reasoning_content as list without training flags on qwen3")
tokenizer = deepcopy(qwen3_tokenizer)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("qwen3"),
message_property_mappings={
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
},
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
# Both as lists, no per-part training flags
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}],
},
{
"role": "assistant",
"reasoning_content": [
{"type": "text", "text": "Step 1: addition."},
{"type": "text", "text": " Step 2: 2+2=4."},
],
"content": [
{"type": "text", "text": "The answer is 4."},
],
},
]
dataset = Dataset.from_dict({"messages": [conversation]})
res = strategy.tokenize_prompt(dataset[0])
# Should tokenize without errors
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
# Verify the full output contains both reasoning and content
full_text = tokenizer.decode(res["input_ids"])
assert "Step 1: addition." in full_text
assert "Step 2: 2+2=4." in full_text
assert "The answer is 4." in full_text

View File

@@ -0,0 +1,63 @@
"""Tests for SkipEvalOnResumeCallback."""
from unittest.mock import MagicMock
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
class TestSkipEvalOnResumeCallback:
"""Tests for skipping redundant evaluation on checkpoint resume."""
@staticmethod
def _make_state(global_step: int) -> TrainerState:
state = MagicMock(spec=TrainerState)
state.global_step = global_step
return state
def test_suppresses_eval_at_resume_step(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(20)
control = TrainerControl(should_evaluate=False)
# Simulate on_train_begin at checkpoint-20
cb.on_train_begin(args, state, control)
# Trainer sets should_evaluate = True for step 20
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is False
def test_allows_eval_after_resume_step(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(20)
control = TrainerControl(should_evaluate=False)
cb.on_train_begin(args, state, control)
# Advance past the resume point
state.global_step = 30
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is True
def test_noop_on_fresh_run(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(0)
control = TrainerControl(should_evaluate=False)
# Fresh run: global_step starts at 0
cb.on_train_begin(args, state, control)
# Even if eval triggers at step 0 (unlikely but defensive)
state.global_step = 10
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is True

tests/utils/data/test_rl.py Normal file
View File

@@ -0,0 +1,292 @@
"""
Unit tests for RL data utility functions (excess_length_strategy support).
"""
import unittest
from axolotl.utils.data.rl import (
_drop_long_sequences,
_raise_on_long_sequences,
_truncate_long_sequences_rl,
)
from axolotl.utils.schemas.enums import RLType
class _FakeTokenizer:
"""Simple whitespace tokenizer for testing length calculations."""
def __call__(self, text, add_special_tokens=True): # noqa: ARG002
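# One token per whitespace-separated word; ids are just positional placeholders.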
tokens = text.split()
return {"input_ids": list(range(len(tokens)))}
def decode(self, token_ids, skip_special_tokens=True): # noqa: ARG002
# Each token id maps to a placeholder word; length is what matters.
return " ".join(f"w{i}" for i in range(len(token_ids)))
def _make_dpo_sample(prompt_len: int, chosen_len: int, rejected_len: int):
"""Create a DPO sample with specified word counts."""
return {
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
"chosen": " ".join(f"c{i}" for i in range(chosen_len)),
"rejected": " ".join(f"r{i}" for i in range(rejected_len)),
}
def _make_kto_sample(prompt_len: int, completion_len: int):
"""Create a KTO sample with specified word counts."""
return {
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
"completion": " ".join(f"c{i}" for i in range(completion_len)),
}
class TestDropLongSequences(unittest.TestCase):
"""Tests for the existing _drop_long_sequences filter function."""
def setUp(self):
self.tokenizer = _FakeTokenizer()
def test_dpo_keeps_short_samples(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
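# prompt + chosen = 3 + 2 = 5 and prompt + rejected = 3 + 2 = 5, both within sequence_len=10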
self.assertTrue(result)
def test_dpo_drops_long_chosen(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
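# prompt + chosen = 5 + 10 = 15 > 10, so the sample is dropped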
self.assertFalse(result)
def test_dpo_drops_long_rejected(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=2, rejected_len=10)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertFalse(result)
def test_kto_keeps_short_samples(self):
sample = _make_kto_sample(prompt_len=3, completion_len=2)
result = _drop_long_sequences(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_kto_drops_long_completion(self):
sample = _make_kto_sample(prompt_len=5, completion_len=10)
result = _drop_long_sequences(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
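# prompt + completion = 5 + 10 = 15 > 10, so the sample is dropped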
self.assertFalse(result)
def test_grpo_always_keeps(self):
sample = {"prompt": "a " * 100}
result = _drop_long_sequences(
sample, RLType.GRPO, self.tokenizer, sequence_len=5
)
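# The GRPO branch keeps samples regardless of length (prompt is 100 words vs sequence_len=5)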
self.assertTrue(result)
def test_dpo_missing_keys_raises(self):
with self.assertRaises(ValueError):
_drop_long_sequences({"prompt": "hi"}, RLType.DPO, self.tokenizer, 10)
def test_kto_missing_keys_raises(self):
with self.assertRaises(ValueError):
_drop_long_sequences({"prompt": "hi"}, RLType.KTO, self.tokenizer, 10)
def test_ipo_uses_dpo_logic(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.IPO, self.tokenizer, sequence_len=10
)
self.assertFalse(result)
def test_orpo_uses_dpo_logic(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _drop_long_sequences(
sample, RLType.ORPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_boundary_length_kept(self):
"""Samples exactly at sequence_len should be kept."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
result = _drop_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
class TestRaiseOnLongSequences(unittest.TestCase):
"""Tests for _raise_on_long_sequences (excess_length_strategy='raise')."""
def setUp(self):
self.tokenizer = _FakeTokenizer()
def test_short_sample_passes(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _raise_on_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertTrue(result)
def test_long_sample_raises_valueerror(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
with self.assertRaises(ValueError, msg="excess_length_strategy"):
_raise_on_long_sequences(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
def test_kto_long_raises(self):
sample = _make_kto_sample(prompt_len=5, completion_len=10)
with self.assertRaises(ValueError):
_raise_on_long_sequences(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
def test_grpo_never_raises(self):
sample = {"prompt": "a " * 100}
result = _raise_on_long_sequences(
sample, RLType.GRPO, self.tokenizer, sequence_len=5
)
self.assertTrue(result)
class TestTruncateLongSequencesRL(unittest.TestCase):
"""Tests for _truncate_long_sequences_rl (excess_length_strategy='truncate')."""
def setUp(self):
self.tokenizer = _FakeTokenizer()
def test_dpo_short_sample_unchanged(self):
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["chosen"], sample["chosen"])
self.assertEqual(result["rejected"], sample["rejected"])
def test_dpo_truncates_chosen(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
# max_response_len = 10 - 5 = 5, chosen had 10 words -> truncated to 5
chosen_tokens = self.tokenizer(result["chosen"], add_special_tokens=False)[
"input_ids"
]
self.assertEqual(len(chosen_tokens), 5)
def test_dpo_truncates_rejected(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=3, rejected_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
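# max_response_len = 10 - 5 = 5, rejected had 10 words -> truncated to 5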
rejected_tokens = self.tokenizer(result["rejected"], add_special_tokens=False)[
"input_ids"
]
self.assertEqual(len(rejected_tokens), 5)
def test_dpo_truncates_both(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
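# Both responses exceed max_response_len = 5 and are truncated independently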
chosen_len = len(
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
)
rejected_len = len(
self.tokenizer(result["rejected"], add_special_tokens=False)["input_ids"]
)
self.assertEqual(chosen_len, 5)
self.assertEqual(rejected_len, 5)
def test_dpo_prompt_unchanged(self):
"""Prompt text should never be modified."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["prompt"], sample["prompt"])
def test_dpo_prompt_exceeds_limit_returns_unchanged(self):
"""When prompt alone exceeds sequence_len, sample is returned as-is."""
sample = _make_dpo_sample(prompt_len=15, chosen_len=3, rejected_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result, sample)
def test_kto_truncates_completion(self):
sample = _make_kto_sample(prompt_len=5, completion_len=10)
result = _truncate_long_sequences_rl(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
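# max_response_len = 10 - 5 = 5, completion had 10 words -> truncated to 5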
completion_len = len(
self.tokenizer(result["completion"], add_special_tokens=False)["input_ids"]
)
self.assertEqual(completion_len, 5)
def test_kto_short_sample_unchanged(self):
sample = _make_kto_sample(prompt_len=3, completion_len=2)
result = _truncate_long_sequences_rl(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["completion"], sample["completion"])
def test_kto_prompt_exceeds_limit_returns_unchanged(self):
sample = _make_kto_sample(prompt_len=15, completion_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.KTO, self.tokenizer, sequence_len=10
)
self.assertEqual(result, sample)
def test_grpo_unchanged(self):
sample = {"prompt": "a " * 100}
result = _truncate_long_sequences_rl(
sample, RLType.GRPO, self.tokenizer, sequence_len=5
)
self.assertEqual(result, sample)
def test_ipo_uses_dpo_logic(self):
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
result = _truncate_long_sequences_rl(
sample, RLType.IPO, self.tokenizer, sequence_len=10
)
chosen_len = len(
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
)
self.assertEqual(chosen_len, 5)
def test_does_not_mutate_original(self):
"""Verify immutability — original sample dict is not modified."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
original_chosen = sample["chosen"]
original_rejected = sample["rejected"]
_truncate_long_sequences_rl(sample, RLType.DPO, self.tokenizer, sequence_len=10)
self.assertEqual(sample["chosen"], original_chosen)
self.assertEqual(sample["rejected"], original_rejected)
def test_dpo_missing_keys_raises(self):
with self.assertRaises(ValueError):
_truncate_long_sequences_rl(
{"prompt": "hi"}, RLType.DPO, self.tokenizer, 10
)
def test_kto_missing_keys_raises(self):
with self.assertRaises(ValueError):
_truncate_long_sequences_rl(
{"prompt": "hi"}, RLType.KTO, self.tokenizer, 10
)
def test_boundary_no_truncation_needed(self):
"""Samples exactly at sequence_len should not be modified."""
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
result = _truncate_long_sequences_rl(
sample, RLType.DPO, self.tokenizer, sequence_len=10
)
self.assertEqual(result["chosen"], sample["chosen"])
self.assertEqual(result["rejected"], sample["rejected"])


@@ -2,6 +2,7 @@ import json
import math
from unittest.mock import Mock, patch
import pytest
+import safetensors.torch
import torch
@@ -773,8 +774,8 @@ class TestEfficientMerge:
"v_proj should be unchanged (no LoRA weights for it)"
)
-def test_dora_missing_magnitude_falls_back(self):
-    """DoRA without magnitude vector falls back to standard LoRA merge."""
+def test_dora_missing_magnitude_raises(self):
+    """DoRA with missing magnitude vector raises an explicit error."""
hidden = 16
r = 4
alpha = 8
@@ -791,11 +792,13 @@ class TestEfficientMerge:
}
config = {"r": r, "lora_alpha": alpha, "use_dora": True}
-merged, was_merged = _merge_tensor_with_lora(
-    base, "layer.proj.weight", lora_state, scale, config, "cpu", use_dora=True
-)
-assert was_merged
-# No magnitude vector → PEFT creates DoRA layer but with default magnitude,
-# which produces a result different from plain W + scale * B @ A.
-# Just verify it was merged (not unchanged).
-assert not torch.equal(merged, base)
+with pytest.raises(ValueError, match="DoRA merge requires a magnitude vector"):
+    _merge_tensor_with_lora(
+        base,
+        "layer.proj.weight",
+        lora_state,
+        scale,
+        config,
+        "cpu",
+        use_dora=True,
+    )