update doc snippets + reject gemma4-hybrid with non-FA2 backend

This commit is contained in:
Wing Lian
2026-04-23 22:18:02 +00:00
parent 39226623d2
commit 434a484fe9
10 changed files with 47 additions and 27 deletions

View File

@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.

View File

@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |

View File

@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```
@@ -304,7 +304,7 @@ lora_alpha: 32
lora_target_linear: true
bf16: auto
flex_attention: true
attn_implementation: flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention

View File

@@ -154,7 +154,7 @@ lr_scheduler: cosine
warmup_steps: 10
bf16: true
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

View File

@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
*Note: You should only enable one attention backend.*
See [Attention](attention.qmd) for the full list of backends and the canonical values.
### LoRA Optimizations

View File

@@ -1147,8 +1147,7 @@ datasets:
type: ebft_strided_structured.transform
split: train[:1%]
flash_attention: false
flex_attention: true # Strided mode uses flex_attention
attn_implementation: flex_attention # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required for flex_attention

View File

@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example

View File

@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
Reduces attention memory from O(n^2) to O(n):
```yaml
flash_attention: true
attn_implementation: flash_attention_2
```
### Step 6: Offload with DeepSpeed

View File

@@ -1413,12 +1413,20 @@ class AxolotlInputConfig(
attn_impl = data.get("attn_implementation")
set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)]
# gemma4_hybrid defaults to flash_attention_2 when user didn't pick a
# backend. The sliding-window layers run under FA2; post-load patching
# swaps global layers to sdpa (see `_apply_gemma_hybrid_attention`).
if data.get("gemma4_hybrid_attn_impl") and not attn_impl and not set_flags:
data["attn_implementation"] = "flash_attention_2"
attn_impl = "flash_attention_2"
# gemma4_hybrid requires flash_attention_2 for the sliding-window layers;
# post-load patching swaps global layers to sdpa (see
# `_apply_gemma_hybrid_attention`). Default it in when the user didn't
# pick a backend; reject any incompatible explicit choice.
if data.get("gemma4_hybrid_attn_impl"):
if not attn_impl and not set_flags:
data["attn_implementation"] = "flash_attention_2"
attn_impl = "flash_attention_2"
elif attn_impl and attn_impl != "flash_attention_2":
raise ValueError(
f"gemma4_hybrid_attn_impl requires attn_implementation="
f"flash_attention_2 (sliding-window layers run under FA2); "
f"got {attn_impl!r}."
)
if attn_impl and set_flags:
raise ValueError(

View File

@@ -116,11 +116,24 @@ class TestGemma4Hybrid:
)
assert result["attn_implementation"] == "flash_attention_2"
def test_gemma4_hybrid_respects_explicit(self):
def test_gemma4_hybrid_with_incompatible_impl_raises(self):
"""Setting gemma4_hybrid alongside a non-FA2 attn_implementation is a
configuration error — the hybrid path requires FA2 under the hood."""
with pytest.raises(
ValueError, match="requires attn_implementation=flash_attention_2"
):
AxolotlInputConfig.normalize_attn_implementation(
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
)
def test_gemma4_hybrid_with_explicit_fa2_passes(self):
result = AxolotlInputConfig.normalize_attn_implementation(
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
{
"gemma4_hybrid_attn_impl": True,
"attn_implementation": "flash_attention_2",
}
)
assert result["attn_implementation"] == "sdpa"
assert result["attn_implementation"] == "flash_attention_2"
class TestFieldValidator: