* upgrade to torchao 0.17.0 * chore: lint * refactor attention handling * replace legacy attention boolean flags with capability properties Replace checks with capability-based properties derived from attn_implementation This separates three concerns that were conflated under flash_attention: 1. Backend selection -> attn_implementation enum 2. Packing capability -> attn_supports_packing property 3. Flash-attn library dependency -> attn_uses_flash_lib property * compute attn capability flags in normalizer instead of properties * make attn_implementation the single source of truth * move attention-dependent validators to mode=after * migrate remaining consumers to canonical attn_implementation * expand attention tests + rewrite docs * migrate example configs to canonical attn_implementation * update doc snippets + reject gemma4-hybrid with non-FA2 backend * remove dead gemma4 branch in _set_attention_config * fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests * drop "Phase 2" naming from attn-implementation tests * regroup attn_implementation tests by feature concern * clean up verbose comments and remove MD Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai> * fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x In transformers 5.x, ProcessorMixin.apply_chat_template gained its own `return_dict` parameter (defaulting to False). When return_dict=False and tokenize=True the method returns out["input_ids"] directly — a 2-D tensor — rather than the full BatchFeature dict. The old code placed `return_dict=True` inside processor_kwargs. In transformers 5.x those kwargs are forwarded to the underlying processor call self(...) where _merge_kwargs silently ignores any key not present in MllamaProcessorKwargs (emitting a warning). 
The outer return_dict therefore stayed False, apply_chat_template returned the raw input_ids tensor, and the subsequent `batch["input_ids"]` attempted to index a 2-D tensor with the 9-character string "input_ids", producing: IndexError: too many indices for tensor of dimension 2 The fix is to pass return_dict=True as a top-level keyword argument to apply_chat_template (where it is actually consumed) and remove it from processor_kwargs (where it was silently dropped). No version guard is needed: transformers is pinned to ==5.5.4 in pyproject.toml. Adds a unit-level regression test (tests/test_mm_chat_collator.py) that mocks the processor to return a raw tensor when apply_chat_template is called without top-level return_dict=True, verifying the four invariants: process_rows returns a dict, input_ids is 2-D, labels is 2-D, and apply_chat_template receives return_dict=True as a top-level kwarg. Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai> * fix(collator): process_rows returns dict (BatchFeature) shape Two related changes for the multimodal chat collator under transformers 5.x: 1. Wrap apply_chat_template result in dict(...) so process_rows returns a plain dict rather than a BatchFeature instance. BatchFeature is a Mapping but not a dict; downstream code that did batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"]) would index on a tensor when the result wasn't dict-shaped, raising IndexError: too many indices for tensor of dimension 2 2. Soften the regression test's contract from `dict` to `Mapping` so it exercises the actual semantic guarantee (key/value access) rather than the implementation detail (dict vs BatchFeature). 
Test guards against the original transformers 5.x breakage where apply_chat_template's return_dict default went from True to False. Includes regression test under tests/test_mm_chat_collator.py. Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against attn-implementation-refactor; squash-merged from agent commits 4de886fd + dc9fcf4f. Signed-off-by: Wing Lian <wing@axolotl.ai> --------- Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
---
title: Attention
description: Supported attention modules in Axolotl
---

Axolotl routes attention via a single config field:

```yaml
attn_implementation: <backend>
```

`attn_implementation` is passed through to `transformers` verbatim (via
`model.config._attn_implementation`). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.

## Backends

| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |

Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted**;
use a canonical name from the table above.

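
The acceptance rule described above (canonical names or a hub-kernel path, no short aliases) can be sketched as a small check. This is an illustration only, not Axolotl's validator: the function name `check_attn_implementation` and the pattern used to recognize an `<org>/<name>` path are assumptions made for this example.

```python
import re

# Canonical backend names, copied from the table above.
CANONICAL_BACKENDS = frozenset({
    "eager", "sdpa", "flash_attention_2", "flash_attention_3",
    "flex_attention", "xformers", "sage", "s2", "fp8",
})

# ASSUMPTION: a hub-kernel path is any "<org>/<name>" string with exactly one
# slash and no whitespace. The real rule may be stricter or looser.
HUB_PATH_RE = re.compile(r"[^/\s]+/[^/\s]+")


def check_attn_implementation(value: str) -> str:
    """Return `value` if it is a canonical name or hub path, else raise."""
    if value in CANONICAL_BACKENDS or HUB_PATH_RE.fullmatch(value):
        return value
    raise ValueError(
        f"Unsupported attn_implementation {value!r}; short aliases such as "
        "'fa2' or 'sdp' are not accepted."
    )


if __name__ == "__main__":
    print(check_attn_implementation("flash_attention_2"))
    print(check_attn_implementation("kernels-community/flash-attn3"))
```

Passing a short alias such as `"fa2"` to this sketch raises `ValueError`, mirroring the "not accepted" rule.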
### Capability flags

Axolotl derives three boolean capability flags from `attn_implementation` and
exposes them on the validated config:

- `cfg.attn_supports_packing`: backend supports varlen sample packing via
  `position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib`: backend needs the `flash_attn` (Dao-AILab)
  monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast`: backend requires fp16/bf16 embeddings
  (everything except `eager` and `sdpa`).

These are **computed**; they cannot be overridden from YAML.

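
To make the idea of computed flags concrete, here is a minimal sketch of deriving them from the backend name. It is not Axolotl's implementation. Only two facts come from this page: `eager`, `sdpa`, and `s2` do not support packing, and every backend except `eager` and `sdpa` needs the dtype cast. The `FLASH_LIB` set and the rule "anything not listed as non-packing supports packing" are assumptions for illustration, and hub-kernel paths are out of scope here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnCapabilities:
    attn_supports_packing: bool
    attn_uses_flash_lib: bool
    attn_needs_dtype_cast: bool


# From this page: backends documented as not supporting sample packing.
NO_PACKING = frozenset({"eager", "sdpa", "s2"})
# From this page: every backend except these requires fp16/bf16 embeddings.
NO_DTYPE_CAST = frozenset({"eager", "sdpa"})
# ASSUMPTION: an illustrative guess at which backends rely on the flash_attn
# library; the authoritative list lives in Axolotl's config normalizer.
FLASH_LIB = frozenset({"flash_attention_2", "flash_attention_3", "s2"})


def derive_capabilities(attn_implementation: str) -> AttnCapabilities:
    """Derive the three flags for a built-in backend name (hypothetical helper)."""
    return AttnCapabilities(
        # ASSUMPTION: any backend not listed in NO_PACKING supports packing.
        attn_supports_packing=attn_implementation not in NO_PACKING,
        attn_uses_flash_lib=attn_implementation in FLASH_LIB,
        attn_needs_dtype_cast=attn_implementation not in NO_DTYPE_CAST,
    )


if __name__ == "__main__":
    print(derive_capabilities("sdpa"))
    print(derive_capabilities("flash_attention_2"))
```

Because the dataclass is frozen, the derived values cannot be reassigned afterwards, which loosely mirrors the rule that the flags cannot be overridden from YAML.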
## Per-backend notes

### SDPA

Default PyTorch attention. See
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).

```yaml
attn_implementation: sdpa
```

### Flash Attention

Axolotl supports FA2, FA3, and FA4. The best available version is used
automatically based on your installed packages and GPU.

```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```

#### Flash Attention 2

Requirements: Ampere, Ada, or Hopper GPUs (Turing and older are not supported).

```bash
pip install flash-attn --no-build-isolation
```

::: {.callout-tip}

If you hit an `undefined symbol` error while training, make sure PyTorch was installed before Axolotl.
Alternatively, try reinstalling `flash-attn` or downgrading to an earlier version.

:::

#### Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 is recommended.

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### Flash Attention 4

Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.

FA4 is still a pre-release on PyPI, so `--pre` is required:

```bash
pip install --pre flash-attn-4
```

Or from source:

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module.
# The path is quoted so install locations containing spaces are handled.
rm -r "$(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute"
```

::: {.callout-note}

**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.

:::

::: {.callout-warning}

FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported, but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.

:::

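
The head-dimension rule from the warning above can be written as a small predicate. This is a sketch, not the check Axolotl runs: the function name and parameters are hypothetical, and reading the DeepSeek shape `(192, 128)` as (query/key head dim, value head dim) is an assumption.

```python
def fa4_supports_head_dim(qk_head_dim: int, v_head_dim: int, is_blackwell: bool) -> bool:
    """Return True if FA4 can handle this head shape, per the rule above."""
    # General rule: head dimensions up to 128 are supported.
    if qk_head_dim <= 128 and v_head_dim <= 128:
        return True
    # Special case: the DeepSeek shape is supported, but only on Blackwell.
    # ASSUMPTION: (192, 128) means (qk_head_dim, v_head_dim).
    if (qk_head_dim, v_head_dim) == (192, 128):
        return is_blackwell
    return False


if __name__ == "__main__":
    print(fa4_supports_head_dim(128, 128, is_blackwell=False))
    print(fa4_supports_head_dim(192, 128, is_blackwell=False))
    print(fa4_supports_head_dim(192, 128, is_blackwell=True))
```

When the predicate is false, the documented behavior is a fallback to FA2/3 rather than an error.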
### AMD

Requirements: ROCm 6.0 and above. See
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).

### Flex Attention

```yaml
attn_implementation: flex_attention
torch_compile: true # recommended
```

Requires torch ≥ 2.6. See the [PyTorch blog post](https://pytorch.org/blog/flexattention/).

### SageAttention

Requirements: Ampere, Ada, or Hopper GPUs.

```yaml
attn_implementation: sage
```

```bash
pip install sageattention==2.2.0 --no-build-isolation
```

::: {.callout-warning}

Only LoRA/QLoRA is recommended. Full fine-tuning has been observed to drop the loss to 0. See
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).

:::

For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).

### xFormers

```yaml
attn_implementation: xformers
```

::: {.callout-tip}

Recommended for Turing or older GPUs (e.g. Colab T4).

:::

### Shifted Sparse Attention

::: {.callout-warning}

Planned for deprecation. Prefer one of the backends above.

:::

Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
patched to implement shifted-sparse attention. Does not support sample packing.

```yaml
attn_implementation: s2
```

### FP8

torchao low-precision attention. Loaded as SDPA and patched post-load.

Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
flash-attn with FA3. KV caching must be disabled.

```yaml
attn_implementation: fp8
```

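
The FP8 requirements listed above can be checked up front before committing to a run. The sketch below is a hypothetical pre-flight helper, not part of Axolotl: the function names, the parameters `has_fa3` and `kv_cache_enabled`, and the message strings are all assumptions. It takes plain values so that it runs without a GPU.

```python
import re


def version_tuple(version: str) -> tuple:
    """Parse leading numeric components, e.g. "2.11.0+cu128" -> (2, 11, 0)."""
    parts = []
    # Drop any local build suffix after "+", then read numeric pieces.
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def fp8_unmet_requirements(compute_capability, torch_version, torchao_version,
                           has_fa3, kv_cache_enabled):
    """Return a list of unmet requirements; an empty list means all are met."""
    unmet = []
    if tuple(compute_capability) < (9, 0):
        unmet.append("GPU must be SM90+ (Hopper/Blackwell)")
    if version_tuple(torch_version) < (2, 11):
        unmet.append("PyTorch >= 2.11 is required")
    if version_tuple(torchao_version) < (0, 17):
        unmet.append("torchao >= 0.17 is required")
    if not has_fa3:
        unmet.append("flash-attn with FA3 is required")
    if kv_cache_enabled:
        unmet.append("KV caching must be disabled")
    return unmet


if __name__ == "__main__":
    print(fp8_unmet_requirements((8, 0), "2.9.1", "0.17.0",
                                 has_fa3=True, kv_cache_enabled=False))
```

With PyTorch installed, the first two arguments could come from `torch.cuda.get_device_capability()` and `torch.__version__`.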
### Hub kernels

```yaml
attn_implementation: kernels-community/flash-attn3
```

Passed through to `transformers`; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
`attn_uses_flash_lib=False`).

## Migrating from legacy boolean flags

The following legacy config fields are **deprecated** and will be removed in a
future release. Each emits a `DeprecationWarning` when set and is stripped from
the validated config.

| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |

Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
flash_attention_2` **and** `flash_attention: true`) raises an error; pick one.

::: {.callout-note}

Existing example configs under `examples/` still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.

:::

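
The legacy-to-canonical mapping in the table above can be applied to a loaded config dict with a few lines of Python. This is a sketch for migrating your own configs, not Axolotl's normalizer: the helper name `migrate_attention_config`, the use of `ValueError`, and the rejection of multiple legacy flags at once are assumptions.

```python
import warnings

# Legacy boolean flag -> canonical attn_implementation, from the table above.
LEGACY_TO_CANONICAL = {
    "flash_attention": "flash_attention_2",
    "sdp_attention": "sdpa",
    "xformers_attention": "xformers",
    "flex_attention": "flex_attention",
    "sage_attention": "sage",
    "s2_attention": "s2",
    "eager_attention": "eager",
}


def migrate_attention_config(cfg: dict) -> dict:
    """Return a copy of `cfg` with legacy flags replaced by attn_implementation."""
    # Legacy keys are always stripped from the result, whatever their value.
    out = {k: v for k, v in cfg.items() if k not in LEGACY_TO_CANONICAL}
    enabled = [k for k in LEGACY_TO_CANONICAL if cfg.get(k)]
    if not enabled:
        return out
    if out.get("attn_implementation"):
        raise ValueError(
            f"attn_implementation is set together with legacy flag(s) {enabled}; pick one."
        )
    # ASSUMPTION: more than one enabled legacy flag is treated as an error.
    if len(enabled) > 1:
        raise ValueError(f"Multiple legacy attention flags set: {enabled}")
    legacy = enabled[0]
    canonical = LEGACY_TO_CANONICAL[legacy]
    warnings.warn(
        f"`{legacy}` is deprecated; use `attn_implementation: {canonical}`.",
        DeprecationWarning,
        stacklevel=2,
    )
    out["attn_implementation"] = canonical
    return out


if __name__ == "__main__":
    print(migrate_attention_config({"flash_attention": True, "sample_packing": True}))
```

The helper returns a new dict instead of mutating its input, so the original config stays available for comparison.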