---
title: Attention
description: Supported attention modules in Axolotl
---

Axolotl routes attention via a single config field:

```yaml
attn_implementation: <backend>
```

`attn_implementation` is passed through to `transformers` verbatim (via
`model.config._attn_implementation`). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.
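All of these are set through the same field, for example:

```yaml
# a named backend from the table below
attn_implementation: flash_attention_2

# or a hub-kernel path
# attn_implementation: kernels-community/flash-attn3
```
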

## Backends

| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |

Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
set the canonical name above.

### Capability flags

Axolotl derives three boolean capability flags from `attn_implementation` and
exposes them on the validated config:

- `cfg.attn_supports_packing` — backend supports varlen sample packing via
  `position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
  monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
  (everything except `eager` and `sdpa`).

These are **computed** — they cannot be overridden from YAML.
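For example, a packing-capable backend can be paired with sample packing, while `eager` and `sdpa` cannot (an illustrative snippet using the standard `sample_packing` and `bf16` options):

```yaml
# flash_attention_2 sets attn_supports_packing, so packing is allowed
attn_implementation: flash_attention_2
sample_packing: true

# flash backends also set attn_needs_dtype_cast (fp16/bf16 required)
bf16: true
```
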

## Per-backend notes

### SDPA

Default PyTorch attention. See
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).

```yaml
attn_implementation: sdpa
```

### Flash Attention

Axolotl supports FA2, FA3, and FA4. The best available version is used
automatically based on your installed packages and GPU.

```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```

#### Flash Attention 2

Requirements: Ampere, Ada, or Hopper GPUs (Turing and older are not supported).

```bash
pip install flash-attn --no-build-isolation
```

::: {.callout-tip}

If you get an `undefined symbol` error while training, ensure you installed PyTorch prior to installing Axolotl.
Alternatively, try reinstalling `flash-attn` or downgrading it to an earlier version.

:::

#### Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 is recommended.

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### Flash Attention 4

Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.

```bash
pip install flash-attn-4
```

Or from source:

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```

::: {.callout-note}

**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.

:::

::: {.callout-warning}

FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.

:::

#### AMD

Requirements: ROCm 6.0 and above. See
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).

### Flex Attention

```yaml
attn_implementation: flex_attention
torch_compile: true # recommended
```

Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).

### SageAttention

Requirements: Ampere, Ada, or Hopper GPUs.

```yaml
attn_implementation: sage
```

```bash
pip install sageattention==2.2.0 --no-build-isolation
```

::: {.callout-warning}

Only LoRA/QLoRA is recommended. Full fine-tuning has been observed to drop the loss to 0. See
[this GitHub issue](https://github.com/thu-ml/SageAttention/issues/198).

:::

For more details, see [SageAttention](https://github.com/thu-ml/SageAttention).
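A minimal LoRA-style config sketch, assuming the usual Axolotl `adapter` and `load_in_4bit` fields (shown only for illustration):

```yaml
# SageAttention with a QLoRA adapter, per the warning above
attn_implementation: sage
adapter: qlora
load_in_4bit: true
```
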

### xFormers

```yaml
attn_implementation: xformers
```

::: {.callout-tip}

Recommended for Turing GPUs or below (e.g. Colab T4).

:::

### Shifted Sparse Attention

::: {.callout-warning}

Planned for deprecation. Prefer one of the backends above.

:::

Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
patched to implement shifted-sparse attention. Does not support sample packing.

```yaml
attn_implementation: s2
```

### FP8

torchao low-precision attention. Loaded as SDPA and patched post-load.

Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
flash-attn with FA3. KV caching must be disabled.

```yaml
attn_implementation: fp8
```

### Hub kernels

```yaml
attn_implementation: kernels-community/flash-attn3
```

Passed through to `transformers`; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
`attn_uses_flash_lib=False`).
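For example (the second path below is a hypothetical placeholder, not a published kernel):

```yaml
# recognized hub path: capability flags are filled in automatically
attn_implementation: kernels-community/flash-attn3

# arbitrary hub path (hypothetical): conservative defaults apply,
# so sample packing stays disabled
# attn_implementation: your-org/your-attention-kernel
```
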

## Migrating from legacy boolean flags

The following legacy config fields are **deprecated** and will be removed in a
future release. Each emits a `DeprecationWarning` when set and is stripped from
the validated config.

| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |

Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation: flash_attention_2`
**and** `flash_attention: true`) raises an error; pick one.
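For example, migrating a legacy Flash Attention config is a one-line change:

```yaml
# before (deprecated; emits a DeprecationWarning)
# flash_attention: true

# after
attn_implementation: flash_attention_2
```
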

::: {.callout-note}

Existing example configs under `examples/` still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.

:::