* upgrade to torchao 0.17.0

* chore: lint

* refactor attention handling

* replace legacy attention boolean flags with capability properties

  Replace checks with capability-based properties derived from attn_implementation.
  This separates three concerns that were conflated under flash_attention:
  1. Backend selection -> attn_implementation enum
  2. Packing capability -> attn_supports_packing property
  3. Flash-attn library dependency -> attn_uses_flash_lib property

* compute attn capability flags in normalizer instead of properties

* make attn_implementation the single source of truth

* move attention-dependent validators to mode=after

* migrate remaining consumers to canonical attn_implementation

* expand attention tests + rewrite docs

* migrate example configs to canonical attn_implementation

* update doc snippets + reject gemma4-hybrid with non-FA2 backend

* remove dead gemma4 branch in _set_attention_config

* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

* drop "Phase 2" naming from attn-implementation tests

* regroup attn_implementation tests by feature concern

* clean up verbose comments and remove MD

  Signed-off-by: Wing Lian <wing@axolotl.ai>
  Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x

  In transformers 5.x, ProcessorMixin.apply_chat_template gained its own `return_dict`
  parameter (defaulting to False). When return_dict=False and tokenize=True, the method
  returns out["input_ids"] directly — a 2-D tensor — rather than the full BatchFeature dict.

  The old code placed `return_dict=True` inside processor_kwargs. In transformers 5.x those
  kwargs are forwarded to the underlying processor call self(...), where _merge_kwargs
  silently ignores any key not present in MllamaProcessorKwargs (emitting a warning). The
  outer return_dict therefore stayed False, apply_chat_template returned the raw input_ids
  tensor, and the subsequent `batch["input_ids"]` attempted to index a 2-D tensor with the
  9-character string "input_ids", producing:

      IndexError: too many indices for tensor of dimension 2

  The fix is to pass return_dict=True as a top-level keyword argument to apply_chat_template
  (where it is actually consumed) and remove it from processor_kwargs (where it was silently
  dropped). No version guard is needed: transformers is pinned to ==5.5.4 in pyproject.toml.

  Adds a unit-level regression test (tests/test_mm_chat_collator.py) that mocks the processor
  to return a raw tensor when apply_chat_template is called without top-level return_dict=True,
  verifying the four invariants: process_rows returns a dict, input_ids is 2-D, labels is 2-D,
  and apply_chat_template receives return_dict=True as a top-level kwarg.

  Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
  Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset

  Signed-off-by: Wing Lian <wing@axolotl.ai>
  Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): process_rows returns dict (BatchFeature) shape

  Two related changes for the multimodal chat collator under transformers 5.x:

  1. Wrap the apply_chat_template result in dict(...) so process_rows returns a plain dict
     rather than a BatchFeature instance. BatchFeature is a Mapping but not a dict; downstream
     code that did batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
     would index on a tensor when the result wasn't dict-shaped, raising
     IndexError: too many indices for tensor of dimension 2.
  2. Soften the regression test's contract from `dict` to `Mapping` so it exercises the actual
     semantic guarantee (key/value access) rather than the implementation detail
     (dict vs BatchFeature).

  The test guards against the original transformers 5.x breakage where apply_chat_template's
  return_dict default went from True to False. Includes a regression test under
  tests/test_mm_chat_collator.py.

  Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
  attn-implementation-refactor; squash-merged from agent commits 4de886fd + dc9fcf4f.

  Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
# New Model Support — Agent Reference

Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.

## Quick Validation Checklist

When testing a new model, run through these checks in order:

1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower

## Loss Debugging

### Expected initial loss

A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
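
For reference, the random-prediction baseline is just the natural log of the vocabulary size:

```python
import math

# uniform guessing over a 262,144-token vocab
print(math.log(262_144))  # ≈ 12.48; a first-step loss near this means the model is predicting at random
```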

### Direct comparison technique

The fastest way to isolate a loss issue — bypass the trainer entirely:

```python
# Load model via axolotl's pipeline (applies all patches)
from axolotl.cli.config import load_cfg
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.loaders.model import ModelLoader

cfg = load_cfg("your_config.yaml")
normalize_config(cfg)
prepare_plugins(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()

# Forward pass on preprocessed data
# (input_ids and labels are tensors taken from a batch of the preprocessed dataset)
model.train()
out = model(input_ids, labels=labels)
print(f"Direct loss: {out.loss.item()}")  # Compare to trainer's reported loss
```

If the direct loss is correct (~1.0) but the trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below).

### `model_accepts_loss_kwargs` inflation

HF Trainer checks whether the model's `forward()` has `**kwargs` and, if so, sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide the loss by `gradient_accumulation_steps` before logging. The gradients are correct — only the logged loss is inflated.

**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.

**Which models are affected**: Any model with `**kwargs` in `forward()` (common in multimodal models that take extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).

**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check whether the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
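
A minimal sketch of that check, assuming the override lives right after `super().__init__()`; the helper name below is hypothetical and the real trainer may unwrap the model differently:

```python
import inspect

def forward_takes_num_items_in_batch(model) -> bool:
    """Hypothetical helper: does the (unwrapped) model's forward() explicitly
    declare num_items_in_batch? If not, the HF Trainer's loss-kwargs path only
    inflates the logged loss."""
    return "num_items_in_batch" in inspect.signature(model.forward).parameters

# Sketch of the override inside the trainer __init__(), after super().__init__():
#     if not forward_takes_num_items_in_batch(unwrapped_model):
#         self.model_accepts_loss_kwargs = False
```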

## Multimodal Models (ForConditionalGeneration)

Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:

- Gemma3 → `Gemma3ForConditionalGeneration`
- Gemma4 → `Gemma4ForConditionalGeneration`
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
- LLaVA → `LlavaForConditionalGeneration`

### Why this matters

| Component | With `ForCausalLM` (default target) | With `ForConditionalGeneration` |
|-----------|-------------------------------------|---------------------------------|
| CCE patches | ✅ | ❌ silently inactive if not patched |
| PEFT LoRA | ✅ | May fail on custom layer types |
| HF Trainer label handling | ✅ | May need extra inputs |

### Required extra inputs

Multimodal models require special inputs during training, even for text-only data:

| Model | Required Input | Value for Text-Only |
|-------|---------------|---------------------|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |

Auto-inject these in `compute_loss()` when the data collator does not provide them. See `core/trainers/base.py` and the sketch below.
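
A hedged sketch of that injection; the helper below is illustrative, not the exact code in `compute_loss()`:

```python
import torch

def inject_text_only_token_type_ids(inputs: dict, model_type: str) -> dict:
    """Illustrative helper: add the text-only placeholder input when the
    collator did not provide it."""
    key = {"gemma4": "mm_token_type_ids", "gemma3": "token_type_ids"}.get(model_type)
    if key and key not in inputs and "input_ids" in inputs:
        # all zeros marks every position as a text token (no image placeholders)
        inputs[key] = torch.zeros_like(inputs["input_ids"])
    return inputs
```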

### Custom layer types and PEFT

Vision towers often use custom module wrappers that PEFT doesn't support:

| Model | Custom Layer | Wraps | Fix |
|-------|-------------|-------|-----|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |

Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.

## Sample Packing

### How packed sequence detection works (transformers ≥ 5.x)

`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:

```python
# From masking_utils.py:
if position_ids is not None and attention_mask is None and past_key_values is None:
    packed_sequence_mask = find_packed_sequence_indices(position_ids)
```

If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
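
Concretely, packed `position_ids` restart at 0 at each sequence boundary, which is what the detection keys on; a small illustration:

```python
import torch

# two packed sequences (lengths 4 and 3) in one row
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])

# the reset at index 4 is the boundary signal the masking utils look for,
# but only when attention_mask is None (see the snippet above)
resets = position_ids[:, 1:] == 0
print(resets)  # tensor([[False, False, False,  True, False, False]])
```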

### Fix for models using `create_causal_mask_mapping`

For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:

```python
# In compute_loss():
if (
    self.args.sample_packing
    and model_type in ("gemma4", "gemma3")
    and "attention_mask" in inputs
    and "position_ids" in inputs
):
    del inputs["attention_mask"]
```

Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.

### Models that DON'T need this fix

Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.

## Attention Backend Selection

| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|----------------|---------------|-------|
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |

**Check model support**: Look at the `_supports_flash_attn_2`, `_supports_flex_attn`, and `_supports_sdpa` attributes on the model class.

**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
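
For example, a config for such a model would pin a non-flash backend (values below are illustrative):

```yaml
# model with global_head_dim > 256 (e.g. Gemma4): flash-attn kernels can't be used
attn_implementation: sdpa   # or flex_attention; both handle arbitrary head_dim
```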

**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.

## Cut Cross Entropy (CCE)

### How CCE patches work

CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
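
For a sense of scale, with illustrative numbers:

```python
# size of the logits tensor CCE avoids materializing:
# one bf16 sequence at 4K context with a 262K vocab
batch, seq_len, vocab_size, dtype_bytes = 1, 4096, 262_144, 2
print(batch * seq_len * vocab_size * dtype_bytes / 2**30, "GiB")  # 2.0 GiB
```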

### Adding CCE for a new model

1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in the `ml-cross-entropy` repo
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`

### Common CCE pitfall

If CCE appears active (the log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
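
One quick way to check, assuming the patch replaces `forward` at the class level as described above (`model` loaded as in the direct-comparison snippet earlier):

```python
# Which class did we actually load, and does its forward point at the CCE patch?
print(type(model).__name__)              # e.g. SomeModelForConditionalGeneration
print(type(model).forward.__qualname__)  # should reference cce_forward, not the stock forward
```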

## MoE Models

### Dense MLP vs MoE experts

Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:

- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)

LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
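
A hedged config sketch of the two targeting styles; module names follow the bullets above and should be verified against the loaded model:

```yaml
# LoRA on the dense MLP path (works normally)
lora_target_modules:
  - gate_proj
  - up_proj
  - down_proj

# LoRA on the MoE expert parameters (needs PEFT support for the expert module type)
lora_target_parameters:
  - experts.gate_up_proj
  - experts.down_proj
```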

### ScatterMoE kernels

`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:

```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
experts_implementation: scattermoe
```

## Where to Add Model-Specific Fixes

| What | Where | Example |
|------|-------|---------|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
| Example configs | `examples/<model>/` | Validated YAML |
| Config validation | `utils/schemas/validation.py` | Compatibility checks |