---
title: Attention
description: Supported attention modules in Axolotl
---

Axolotl routes attention via a single config field:

```yaml
attn_implementation:
```

`attn_implementation` is passed through to `transformers` verbatim (via `model.config._attn_implementation`). Accepted values are the HF-native backends, Axolotl-registered backends, or a hub-kernel path.

## Backends

| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper only). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other hub path containing `/` | Any hub-kernel path supported by `transformers`. |

Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted**; use the canonical names above.

### Capability flags

Axolotl derives three boolean capability flags from `attn_implementation` and exposes them on the validated config:

- `cfg.attn_supports_packing`: the backend supports varlen sample packing via `position_ids`. Gates the multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib`: the backend needs the `flash_attn` (Dao-AILab) monkeypatches (FA4 auto-selection, the LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast`: the backend requires fp16/bf16 embeddings (true for every backend except `eager` and `sdpa`).

These flags are **computed**; they cannot be overridden from YAML.

## Per-backend notes

### SDPA

Default PyTorch attention.
See [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).

```yaml
attn_implementation: sdpa
```

### Flash Attention

Axolotl supports FA2, FA3, and FA4. The best available version is used automatically based on your installed packages and GPU.

```yaml
attn_implementation: flash_attention_2  # or flash_attention_3
```

#### Flash Attention 2

Requirements: Ampere, Ada, or Hopper GPUs (Turing or older is not supported).

```bash
pip install flash-attn --no-build-isolation
```

::: {.callout-tip}
If you get an `undefined symbol` error while training, make sure you installed PyTorch before Axolotl. Alternatively, try reinstalling or downgrading `flash-attn`.
:::

#### Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 recommended.

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### Flash Attention 4

Requirements: Hopper or Blackwell GPUs. FA4 is applied automatically when `attn_uses_flash_lib` is true and FA4 is importable.

```bash
pip install flash-attn-4
```

Or from source:

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r "$(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute"
```

::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4 for training on Hopper, install from source using the instructions above.
:::

::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is also supported, but only on Blackwell. Axolotl automatically detects incompatible head dimensions and falls back to FA2/3.
:::

### AMD

Requirements: ROCm 6.0 and above.
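
A minimal config sketch for AMD GPUs, assuming the ROCm build of `flash-attn` is installed and is selected through the same FA2 backend name (this page does not list a separate AMD value, so treat the mapping as an assumption):

```yaml
# Assumption: on ROCm, the flash-attn build is used via the FA2 backend name
attn_implementation: flash_attention_2
```
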
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).

### Flex Attention

```yaml
attn_implementation: flex_attention
torch_compile: true  # recommended
```

Requires torch ≥ 2.6. See the [PyTorch FlexAttention blog post](https://pytorch.org/blog/flexattention/).

### SageAttention

Requirements: Ampere, Ada, or Hopper GPUs.

```yaml
attn_implementation: sage
```

```bash
pip install sageattention==2.2.0 --no-build-isolation
```

::: {.callout-warning}
Only LoRA/QLoRA is recommended. Full fine-tuning has been observed to drive the loss to 0. See this [GitHub issue](https://github.com/thu-ml/SageAttention/issues/198).
:::

For more details, see [SageAttention](https://github.com/thu-ml/SageAttention).

### xFormers

```yaml
attn_implementation: xformers
```

::: {.callout-tip}
Recommended for Turing GPUs or older (e.g. Colab T4).
:::

### Shifted Sparse Attention

::: {.callout-warning}
Planned for deprecation. Prefer one of the backends above.
:::

Requirements: LLaMA model architecture. Loaded as FA2 under the hood and patched to implement shifted-sparse attention. Does not support sample packing.

```yaml
attn_implementation: s2
```

### FP8

torchao low-precision attention. Loaded as SDPA and patched post-load.

Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17, and flash-attn with FA3. KV caching must be disabled.

```yaml
attn_implementation: fp8
```

### Hub kernels

```yaml
attn_implementation: kernels-community/flash-attn3
```

The value is passed through to `transformers`; Axolotl does not install the kernel itself. For recognized hub paths, the capability flags are set automatically; for arbitrary paths, Axolotl uses conservative defaults (`attn_supports_packing=False`, `attn_uses_flash_lib=False`).

## Migrating from legacy boolean flags

The following legacy config fields are **deprecated** and will be removed in a future release. Each emits a `DeprecationWarning` when set and is stripped from the validated config.

| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |

Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation: flash_attention_2` **and** `flash_attention: true`) raises an error; pick one.

::: {.callout-note}
Existing example configs under `examples/` still use the legacy flags. They continue to work with a deprecation warning and will be migrated in a follow-up pass.
:::
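
To migrate a config by hand, replace each legacy boolean with its canonical value from the migration table. A minimal before/after sketch for Flash Attention 2, with all other keys omitted; the legacy line is commented out so the snippet stays valid if copied as-is:

```yaml
# Before (deprecated, emits a DeprecationWarning):
# flash_attention: true

# After (canonical form):
attn_implementation: flash_attention_2
```

Delete the legacy key rather than keeping it next to `attn_implementation`, since setting both raises an error.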