feat: add FA4 (#3481)

* feat: add FA4

* chore: update docs

* fix: recommend FA4 for those with compatible devices

* fix: adjust import check and add head_dim check

* chore: add limitation to doc

* fix: log warning and quit if cannot import validator

* chore: simplify

* fix: add caveat with FA2 shadow dir
This commit is contained in:
NanoCode012
2026-03-16 11:13:18 +07:00
committed by GitHub
parent 4a5876df7a
commit 7da5f94379
4 changed files with 161 additions and 9 deletions

View File

@@ -13,9 +13,10 @@ sdp_attention: true
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
## Flash Attention
Uses efficient kernels to compute attention.
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
```yaml
flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
```bash
pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
@@ -50,6 +50,44 @@ cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
```
Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```
::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.
:::
::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.