diff --git a/docs/agents/new_model_support.md b/docs/agents/new_model_support.md index 8e6028896..bc42ada86 100644 --- a/docs/agents/new_model_support.md +++ b/docs/agents/new_model_support.md @@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2 | Backend | Config | head_dim limit | torch_compile | Notes | |---------|--------|---------------|---------------|-------| -| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported | -| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ | -| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback | -| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims | -| eager | neither set | None | ✅ | Slowest, always works | +| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported | +| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ | +| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback | +| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims | +| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works | **Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class. diff --git a/docs/agents/sft.md b/docs/agents/sft.md index d3dfd39f7..f601cb0f5 100644 --- a/docs/agents/sft.md +++ b/docs/agents/sft.md @@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go | Issue | Fix | |-------|-----| | OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` | -| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` | +| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` | | Missing chat template error | Set `chat_template: chatml` explicitly | | Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels | | Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples | diff --git a/docs/ebft.qmd b/docs/ebft.qmd index eb7c95eca..d9afc3307 100644 --- a/docs/ebft.qmd +++ b/docs/ebft.qmd @@ -129,7 +129,7 @@ gradient_accumulation_steps: 4 max_steps: 20 learning_rate: 5.0e-6 bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true output_dir: ./outputs/ebft-quickstart ``` @@ -304,7 +304,7 @@ lora_alpha: 32 lora_target_linear: true bf16: auto -flex_attention: true +attn_implementation: flex_attention gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true # Required with flex_attention diff --git a/docs/grpo.qmd b/docs/grpo.qmd index 35631f136..a98dbe11d 100644 --- a/docs/grpo.qmd +++ b/docs/grpo.qmd @@ -154,7 +154,7 @@ lr_scheduler: cosine warmup_steps: 10 bf16: true -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/docs/optimizations.qmd b/docs/optimizations.qmd index b180387ed..720519ec0 100644 --- a/docs/optimizations.qmd +++ b/docs/optimizations.qmd @@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac Using an optimized attention implementation is critical for training speed. -- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. 
**(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support). -- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`. -- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation. -- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16. +- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support). +- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`. +- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation. +- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16. -*Note: You should only enable one attention backend.* +See [Attention](attention.qmd) for the full list of backends and the canonical values. ### LoRA Optimizations diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 75d20414c..a27bb2966 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -1147,8 +1147,7 @@ datasets: type: ebft_strided_structured.transform split: train[:1%] -flash_attention: false -flex_attention: true # Strided mode uses flex_attention +attn_implementation: flex_attention # Strided mode uses flex_attention gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true # Required for flex_attention diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index d1933a145..9799c8a70 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -55,7 +55,7 @@ To use sequence parallelism, you need: ## Limitations -- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML) +- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML) - May have a small performance overhead due to communication between GPUs ## Example diff --git a/docs/training_stability.qmd b/docs/training_stability.qmd index e2cd79f89..9849a35d1 100644 --- a/docs/training_stability.qmd +++ b/docs/training_stability.qmd @@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with Reduces attention memory from O(n^2) to O(n): ```yaml -flash_attention: true +attn_implementation: flash_attention_2 ``` ### Step 6: Offload with DeepSpeed diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 366576f0c..3e7fd5ec0 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1413,12 +1413,20 @@ class AxolotlInputConfig( attn_impl = data.get("attn_implementation") set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)] - # gemma4_hybrid defaults to flash_attention_2 when user didn't pick a - # backend. The sliding-window layers run under FA2; post-load patching - # swaps global layers to sdpa (see `_apply_gemma_hybrid_attention`). 
-        if data.get("gemma4_hybrid_attn_impl") and not attn_impl and not set_flags:
-            data["attn_implementation"] = "flash_attention_2"
-            attn_impl = "flash_attention_2"
+        # gemma4_hybrid requires flash_attention_2 for the sliding-window layers;
+        # post-load patching swaps global layers to sdpa (see
+        # `_apply_gemma_hybrid_attention`). Default to it when the user didn't
+        # pick a backend; reject any incompatible explicit choice.
+        if data.get("gemma4_hybrid_attn_impl"):
+            if not attn_impl and not set_flags:
+                data["attn_implementation"] = "flash_attention_2"
+                attn_impl = "flash_attention_2"
+            elif attn_impl and attn_impl != "flash_attention_2":
+                raise ValueError(
+                    "gemma4_hybrid_attn_impl requires attn_implementation="
+                    "flash_attention_2 (sliding-window layers run under FA2); "
+                    f"got {attn_impl!r}."
+                )

         if attn_impl and set_flags:
             raise ValueError(
diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py
index 44628fd33..e0769ffc5 100644
--- a/tests/test_attn_implementation.py
+++ b/tests/test_attn_implementation.py
@@ -116,11 +116,24 @@ class TestGemma4Hybrid:
         )
         assert result["attn_implementation"] == "flash_attention_2"

-    def test_gemma4_hybrid_respects_explicit(self):
+    def test_gemma4_hybrid_with_incompatible_impl_raises(self):
+        """Setting gemma4_hybrid alongside a non-FA2 attn_implementation is a
+        configuration error; the hybrid path requires FA2 under the hood."""
+        with pytest.raises(
+            ValueError, match="requires attn_implementation=flash_attention_2"
+        ):
+            AxolotlInputConfig.normalize_attn_implementation(
+                {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
+            )
+
+    def test_gemma4_hybrid_with_explicit_fa2_passes(self):
         result = AxolotlInputConfig.normalize_attn_implementation(
-            {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
+            {
+                "gemma4_hybrid_attn_impl": True,
+                "attn_implementation": "flash_attention_2",
+            }
         )
-        assert result["attn_implementation"] == "sdpa"
+        assert result["attn_implementation"] == "flash_attention_2"


 class TestFieldValidator:
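For reviewers, a minimal sketch of the resulting validator behavior, mirroring the tests in this diff. It assumes `AxolotlInputConfig.normalize_attn_implementation` can be called directly with a plain config dict, exactly as the tests above do; no new API is introduced here.

```python
import pytest

from axolotl.utils.schemas.config import AxolotlInputConfig

# No backend chosen: gemma4_hybrid_attn_impl defaults the backend to FA2.
data = AxolotlInputConfig.normalize_attn_implementation(
    {"gemma4_hybrid_attn_impl": True}
)
assert data["attn_implementation"] == "flash_attention_2"

# An explicit flash_attention_2 is accepted unchanged.
data = AxolotlInputConfig.normalize_attn_implementation(
    {"gemma4_hybrid_attn_impl": True, "attn_implementation": "flash_attention_2"}
)
assert data["attn_implementation"] == "flash_attention_2"

# Any other explicit backend was previously respected (see the removed
# test_gemma4_hybrid_respects_explicit); it is now rejected up front.
with pytest.raises(
    ValueError, match="requires attn_implementation=flash_attention_2"
):
    AxolotlInputConfig.normalize_attn_implementation(
        {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
    )
```

Failing fast on an incompatible explicit choice seems safer than the old behavior of honoring it, since per the new comment `_apply_gemma_hybrid_attention` assumes the sliding-window layers run under FA2.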