feat: add FA4 (#3481)
* feat: add FA4 * chore: update docs * fix: recommend FA4 for those with compatible devices * fix: adjust import check and add head_dim check * chore: add limitation to doc * fix: log warning and quit if cannot import validator * chore: simplify * fix: add caveat with FA2 shadow dir
This commit is contained in:
@@ -13,9 +13,10 @@ sdp_attention: true
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
## Flash Attention 2
|
||||
## Flash Attention
|
||||
|
||||
Uses efficient kernels to compute attention.
|
||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||
based on your installed packages and GPU.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
@@ -23,11 +24,9 @@ flash_attention: true
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
|
||||
### Nvidia
|
||||
### Flash Attention 2
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
Note: For Turing GPUs or lower, please use other attention methods.
|
||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
|
||||
Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
#### Flash Attention 3
|
||||
### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
@@ -50,6 +50,44 @@ cd flash-attention/hopper
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Flash Attention 4
|
||||
|
||||
Requirements: Hopper or Blackwell GPUs
|
||||
|
||||
```bash
|
||||
pip install flash-attn-4
|
||||
```
|
||||
|
||||
Or from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/flash_attn/cute
|
||||
|
||||
pip install -e .
|
||||
|
||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||
# Remove it so Python can find the real FA4 module:
|
||||
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
|
||||
for training on Hopper, install from source using the instructions above.
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
|
||||
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
|
||||
and falls back to FA2/3.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
|
||||
Reference in New Issue
Block a user