update doc snippets + reject gemma4-hybrid with non-FA2 backend
This commit is contained in:
@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
|
|||||||
|
|
||||||
| Backend | Config | head_dim limit | torch_compile | Notes |
|
| Backend | Config | head_dim limit | torch_compile | Notes |
|
||||||
|---------|--------|---------------|---------------|-------|
|
|---------|--------|---------------|---------------|-------|
|
||||||
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
|
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
|
||||||
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||||
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
|
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
|
||||||
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||||
| eager | neither set | None | ✅ | Slowest, always works |
|
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
|
||||||
|
|
||||||
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
|||||||
| Issue | Fix |
|
| Issue | Fix |
|
||||||
|-------|-----|
|
|-------|-----|
|
||||||
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
||||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
|
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
|
||||||
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
||||||
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
||||||
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
|
|||||||
max_steps: 20
|
max_steps: 20
|
||||||
learning_rate: 5.0e-6
|
learning_rate: 5.0e-6
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
output_dir: ./outputs/ebft-quickstart
|
output_dir: ./outputs/ebft-quickstart
|
||||||
```
|
```
|
||||||
@@ -304,7 +304,7 @@ lora_alpha: 32
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flex_attention: true
|
attn_implementation: flex_attention
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true # Required with flex_attention
|
use_reentrant: true # Required with flex_attention
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ lr_scheduler: cosine
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
|
|
||||||
bf16: true
|
bf16: true
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
|
|||||||
|
|
||||||
Using an optimized attention implementation is critical for training speed.
|
Using an optimized attention implementation is critical for training speed.
|
||||||
|
|
||||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
|
||||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
|
||||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
|
||||||
|
|
||||||
*Note: You should only enable one attention backend.*
|
See [Attention](attention.qmd) for the full list of backends and the canonical values.
|
||||||
|
|
||||||
### LoRA Optimizations
|
### LoRA Optimizations
|
||||||
|
|
||||||
|
|||||||
@@ -1147,8 +1147,7 @@ datasets:
|
|||||||
type: ebft_strided_structured.transform
|
type: ebft_strided_structured.transform
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
|
|
||||||
flash_attention: false
|
attn_implementation: flex_attention # Strided mode uses flex_attention
|
||||||
flex_attention: true # Strided mode uses flex_attention
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true # Required for flex_attention
|
use_reentrant: true # Required for flex_attention
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
|
|||||||
|
|
||||||
## Limitations
|
## Limitations
|
||||||
|
|
||||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
|
||||||
- May have a small performance overhead due to communication between GPUs
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
|
|||||||
Reduces attention memory from O(n^2) to O(n):
|
Reduces attention memory from O(n^2) to O(n):
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 6: Offload with DeepSpeed
|
### Step 6: Offload with DeepSpeed
|
||||||
|
|||||||
@@ -1413,12 +1413,20 @@ class AxolotlInputConfig(
|
|||||||
attn_impl = data.get("attn_implementation")
|
attn_impl = data.get("attn_implementation")
|
||||||
set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)]
|
set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)]
|
||||||
|
|
||||||
# gemma4_hybrid defaults to flash_attention_2 when user didn't pick a
|
# gemma4_hybrid requires flash_attention_2 for the sliding-window layers;
|
||||||
# backend. The sliding-window layers run under FA2; post-load patching
|
# post-load patching swaps global layers to sdpa (see
|
||||||
# swaps global layers to sdpa (see `_apply_gemma_hybrid_attention`).
|
# `_apply_gemma_hybrid_attention`). Default it in when the user didn't
|
||||||
if data.get("gemma4_hybrid_attn_impl") and not attn_impl and not set_flags:
|
# pick a backend; reject any incompatible explicit choice.
|
||||||
data["attn_implementation"] = "flash_attention_2"
|
if data.get("gemma4_hybrid_attn_impl"):
|
||||||
attn_impl = "flash_attention_2"
|
if not attn_impl and not set_flags:
|
||||||
|
data["attn_implementation"] = "flash_attention_2"
|
||||||
|
attn_impl = "flash_attention_2"
|
||||||
|
elif attn_impl and attn_impl != "flash_attention_2":
|
||||||
|
raise ValueError(
|
||||||
|
f"gemma4_hybrid_attn_impl requires attn_implementation="
|
||||||
|
f"flash_attention_2 (sliding-window layers run under FA2); "
|
||||||
|
f"got {attn_impl!r}."
|
||||||
|
)
|
||||||
|
|
||||||
if attn_impl and set_flags:
|
if attn_impl and set_flags:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -116,11 +116,24 @@ class TestGemma4Hybrid:
|
|||||||
)
|
)
|
||||||
assert result["attn_implementation"] == "flash_attention_2"
|
assert result["attn_implementation"] == "flash_attention_2"
|
||||||
|
|
||||||
def test_gemma4_hybrid_respects_explicit(self):
|
def test_gemma4_hybrid_with_incompatible_impl_raises(self):
|
||||||
|
"""Setting gemma4_hybrid alongside a non-FA2 attn_implementation is a
|
||||||
|
configuration error — the hybrid path requires FA2 under the hood."""
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="requires attn_implementation=flash_attention_2"
|
||||||
|
):
|
||||||
|
AxolotlInputConfig.normalize_attn_implementation(
|
||||||
|
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_gemma4_hybrid_with_explicit_fa2_passes(self):
|
||||||
result = AxolotlInputConfig.normalize_attn_implementation(
|
result = AxolotlInputConfig.normalize_attn_implementation(
|
||||||
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
|
{
|
||||||
|
"gemma4_hybrid_attn_impl": True,
|
||||||
|
"attn_implementation": "flash_attention_2",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
assert result["attn_implementation"] == "sdpa"
|
assert result["attn_implementation"] == "flash_attention_2"
|
||||||
|
|
||||||
|
|
||||||
class TestFieldValidator:
|
class TestFieldValidator:
|
||||||
|
|||||||
Reference in New Issue
Block a user