---
title: Attention
description: Supported attention modules in Axolotl
---

## SDP Attention

This is the default built-in attention in PyTorch.

```yaml
sdp_attention: true
```

For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
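
For illustration, here is a minimal sketch comparing SDPA against the textbook definition of causal attention (the tensor shapes are arbitrary toy values, not anything Axolotl uses):

```python
import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# SDPA dispatches to the fastest available backend (Flash,
# memory-efficient, or the math fallback) for the device and dtype.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: softmax(QK^T / sqrt(d)) V with an upper-triangular causal mask.
d = q.shape[-1]
scores = (q @ k.transpose(-2, -1)) / d**0.5
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
ref = scores.softmax(dim=-1) @ v

print(torch.allclose(out, ref, atol=1e-5))  # True
```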

## Flash Attention 2

Uses efficient fused kernels to compute exact attention with reduced memory traffic.

```yaml
flash_attention: true
```

For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)

### Nvidia

Requirements: Ampere, Ada, or Hopper GPUs.

Note: for Turing or older GPUs, please use one of the other attention methods.

```bash
pip install flash-attn --no-build-isolation
```
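
The GPU requirement above can be checked from Python before installing. A minimal sketch (the helper name is our own, not an Axolotl or flash-attn API):

```python
import torch

def flash_attn_2_supported() -> bool:
    """Flash Attention 2 needs an NVIDIA GPU with compute
    capability 8.0 (Ampere) or newer."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

print(flash_attn_2_supported())
```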

::: {.callout-tip}

If you get an `undefined symbol` error while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstalling `flash-attn` or downgrading it by one version.

:::

#### Flash Attention 3

Requirements: Hopper GPUs only; CUDA 12.8 recommended.

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

### AMD

Requirements: ROCm 6.0 or newer.

See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).

## Flex Attention

A flexible PyTorch API for attention, used in combination with `torch.compile`.

```yaml
flex_attention: true

# recommended
torch_compile: true
```

::: {.callout-note}

We recommend using the latest stable version of PyTorch for best performance.

:::

For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)

## SageAttention

Attention kernels using INT8 quantization for QK and an FP16 accumulator for PV.

```yaml
sage_attention: true
```

Requirements: Ampere, Ada, or Hopper GPUs.

```bash
pip install sageattention==2.2.0 --no-build-isolation
```

::: {.callout-warning}

Only LoRA/QLoRA is recommended at the moment: we found the loss drops to 0 with full fine-tuning. See this [GitHub issue](https://github.com/thu-ml/SageAttention/issues/198).

:::

For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)

::: {.callout-note}

We do not support SageAttention 3 at the moment. If you are interested in adding it or improving the SageAttention integration, please open an issue.

:::

## xFormers

```yaml
xformers_attention: true
```

::: {.callout-tip}

We recommend using this with Turing GPUs or older (such as on Colab).

:::

For more details: [xFormers](https://github.com/facebookresearch/xformers)

## Shifted Sparse Attention

::: {.callout-warning}

We plan to deprecate this feature! If you use it, we recommend switching to one of the methods above.

:::

Requirements: LLaMA model architecture.

```yaml
flash_attention: true
s2_attention: true
```

::: {.callout-tip}

No sample packing support!

:::