---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention

This is the default built-in attention in PyTorch.

```yaml
sdp_attention: true
```
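For reference, this is the PyTorch primitive the flag enables; a minimal standalone call (shapes and dtype are illustrative only) looks roughly like:

```python
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

# Causal self-attention; PyTorch dispatches to the best available kernel.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```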
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention

Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically based on your installed packages and GPU.

```yaml
flash_attention: true
```
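Nothing extra is needed at training time, but if you want to see which backend is likely to be picked, you can inspect your environment. A rough sketch (the module names are assumptions about how each package installs; the actual selection logic lives inside Axolotl):

```python
import importlib.util

import torch

# FA2 ships as `flash_attn`; the FA3 Hopper build is assumed to install `flash_attn_interface`.
has_fa2 = importlib.util.find_spec("flash_attn") is not None
has_fa3 = importlib.util.find_spec("flash_attn_interface") is not None

# (8, x) = Ampere/Ada, (9, 0) = Hopper, (10, x) / (12, x) = Blackwell
major, minor = torch.cuda.get_device_capability()

print(f"flash_attn 2: {has_fa2}, flash_attn 3: {has_fa3}, compute capability: {major}.{minor}")
```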
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Flash Attention 2

Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)

```bash
pip install flash-attn --no-build-isolation
```

::: {.callout-tip}
If you get an `undefined symbol` error while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstalling flash-attn or downgrading to an earlier version.
:::
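A quick way to confirm the install is consistent with your current PyTorch build (purely a diagnostic, not an Axolotl command):

```python
# Both imports should succeed without `undefined symbol` errors.
import torch
import flash_attn

print(torch.__version__, flash_attn.__version__)
```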
### Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 is recommended

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper

python setup.py install
```
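After building, you can verify the result from Python. This assumes the Hopper build installs a module named `flash_attn_interface`:

```python
# If this import fails, the FA3 build did not install correctly.
import flash_attn_interface  # assumed module name for the Hopper build

print("FA3 import OK")
```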
### Flash Attention 4

Requirements: Hopper or Blackwell GPUs

FA4 is still a pre-release on PyPI, so `--pre` is required:

```bash
pip install --pre flash-attn-4
```

Or from source:

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute

pip install -e .

# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```

::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4 for training on Hopper, install from source using the instructions above.
:::

::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is also supported, but only on Blackwell. Axolotl automatically detects incompatible head dimensions and falls back to FA2/3.
:::
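If you are unsure whether your model fits the FA4 head-dimension limit, you can compute it from the model config before training. A small sketch using `transformers` (the attribute names are common defaults and can differ per architecture):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B")  # replace with your base model

# Some configs store head_dim directly; otherwise derive it from hidden_size / num_attention_heads.
head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
print(f"head_dim={head_dim}, FA4 compatible: {head_dim <= 128}")
```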
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD

Requirements: ROCm 6.0 and above.

See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention

A flexible PyTorch attention API, used in combination with `torch.compile`.

```yaml
flex_attention: true

# recommended
torch_compile: true
```

::: {.callout-note}
We recommend using the latest stable version of PyTorch for best performance.
:::
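Under the hood this uses PyTorch's FlexAttention API. A minimal standalone sketch, with the relative-position `score_mod` taken from the PyTorch blog post (shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 256, 64, device="cuda")
k = torch.randn(1, 8, 256, 64, device="cuda")
v = torch.randn(1, 8, 256, 64, device="cuda")

# score_mod lets you rewrite each attention score before softmax.
def relative_positional(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# Compiling flex_attention is what makes it fast.
flex_attention = torch.compile(flex_attention)
out = flex_attention(q, k, v, score_mod=relative_positional)
```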
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention

Attention kernels with INT8 QK and an FP16 accumulator for PV.

```yaml
sage_attention: true
```

Requirements: Ampere, Ada, or Hopper GPUs

```bash
pip install sageattention==2.2.0 --no-build-isolation
```

::: {.callout-warning}
Only LoRA/QLoRA is recommended at the moment. We found the loss drops to 0 with full fine-tuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::

For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)

::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention implementation, please open an Issue.
:::
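Axolotl patches the model's attention for you when `sage_attention: true` is set. For reference, a standalone call looks roughly like this, assuming the `sageattn` entry point from the upstream package:

```python
import torch
from sageattention import sageattn  # assumed public API of the sageattention package

# (batch, num_heads, seq_len, head_dim), fp16/bf16 inputs on CUDA
q = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
```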
## xFormers

```yaml
xformers_attention: true
```

::: {.callout-tip}
We recommend using this with Turing GPUs or older (such as on Colab).
:::
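The flag enables xFormers' memory-efficient attention kernels; called directly, they look roughly like this (shapes are illustrative):

```python
import torch
import xformers.ops as xops

# xFormers expects (batch, seq_len, num_heads, head_dim)
q = torch.randn(1, 256, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 256, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 256, 8, 64, dtype=torch.float16, device="cuda")

# Causal attention via the built-in lower-triangular bias
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
```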
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention

::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.
:::

Requirements: LLaMA model architecture

```yaml
flash_attention: true
s2_attention: true
```

::: {.callout-tip}
No sample packing support!
:::