expand attention tests + rewrite docs

This commit is contained in:
Wing Lian
2026-04-23 21:30:20 +00:00
parent a0d24bcc19
commit 2d64d009d8
2 changed files with 262 additions and 61 deletions

@@ -3,28 +3,71 @@ title: Attention
description: Supported attention modules in Axolotl
---
Axolotl routes attention via a single config field:
```yaml
attn_implementation: <backend>
```
`attn_implementation` is passed through to `transformers` verbatim (via
`model.config._attn_implementation`). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.
## Backends
| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
set the canonical name above.
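For example:
```yaml
# not accepted (short-form alias):
# attn_implementation: fa2

# accepted (canonical name):
attn_implementation: flash_attention_2
```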
### Capability flags
Axolotl derives three boolean capability flags from `attn_implementation` and
exposes them on the validated config:
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
(everything except `eager` and `sdpa`).
These are **computed** — they cannot be overridden from YAML.
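As an illustration, here is how the flags come out for one backend (values inferred from the table and flag definitions above; the authoritative mapping lives in axolotl's config validation):
```yaml
attn_implementation: flash_attention_2
# derived flags (read-only, not settable from YAML):
#   attn_supports_packing: true  # FA2 supports varlen packing via position_ids
#   attn_uses_flash_lib: true    # backed by the Dao-AILab flash_attn package
#   attn_needs_dtype_cast: true  # every backend except eager and sdpa
```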
## Per-backend notes
### SDPA
Default PyTorch attention. See
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
```yaml
attn_implementation: sdpa
```
### Flash Attention
Axolotl supports FA2, FA3, and FA4. The best available version is used
automatically based on your installed packages and GPU.
```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```
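FA2/FA3 support varlen sample packing (see the capability flags above), so they can be combined with packing; a minimal sketch, assuming axolotl's standard `sample_packing` field:
```yaml
attn_implementation: flash_attention_2
sample_packing: true # packed samples are delimited via position_ids
```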
#### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
@@ -39,20 +82,20 @@ Alternatively, try reinstalling or downgrading the version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
#### Flash Attention 4
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.
```bash
pip install flash-attn-4
```
@@ -63,7 +106,6 @@ Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
```
@@ -86,93 +128,113 @@ and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above. See
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
### Flex Attention
A flexible PyTorch API for attention, used in combination with `torch.compile`.
```yaml
attn_implementation: flex_attention
torch_compile: true # recommended
```
::: {.callout-note}
Requires torch ≥ 2.6; we recommend the latest stable PyTorch for best performance. See the
[PyTorch docs](https://pytorch.org/blog/flexattention/).
:::
### SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
Requirements: Ampere, Ada, or Hopper GPUs.
```yaml
attn_implementation: sage
```
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
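A minimal sketch following that recommendation, assuming axolotl's standard `adapter` field:
```yaml
attn_implementation: sage
adapter: lora # full finetuning with sage has been observed to collapse loss to 0
```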
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention integration, please open an Issue.
:::
### xFormers
```yaml
attn_implementation: xformers
```
::: {.callout-tip}
Recommended for Turing GPUs or below (e.g. Colab T4).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
### Shifted Sparse Attention
::: {.callout-warning}
Planned for deprecation. Prefer one of the backends above.
:::
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
patched to implement shifted-sparse attention. Does not support sample packing.
```yaml
attn_implementation: s2
```
### FP8
torchao low-precision attention. Loaded as SDPA and patched post-load.
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
flash-attn with FA3. KV caching must be disabled.
```yaml
attn_implementation: fp8
```
### Hub kernels
```yaml
attn_implementation: kernels-community/flash-attn3
```
Passed through to `transformers`; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
`attn_uses_flash_lib=False`).
## Migrating from legacy boolean flags
The following legacy config fields are **deprecated** and will be removed in a
future release. Each emits a `DeprecationWarning` when set and is stripped from
the validated config.
| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
flash_attention_2` **and** `flash_attention: true`) raises a validation error; set only one.
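For example, migrating an FA2 config:
```yaml
# before (deprecated, emits a DeprecationWarning):
# flash_attention: true

# after:
attn_implementation: flash_attention_2
```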
::: {.callout-note}
Existing example configs under `examples/` still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.
:::