feat: add FA4 (#3481)
* feat: add FA4
* chore: update docs
* fix: recommend FA4 for those with compatible devices
* fix: adjust import check and add head_dim check
* chore: add limitation to doc
* fix: log warning and quit if cannot import validator
* chore: simplify
* fix: add caveat with FA2 shadow dir
@@ -75,7 +75,7 @@ Features:
 - **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
 - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
 - **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
-- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
 
@@ -13,9 +13,10 @@ sdp_attention: true
 
 For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
 
-## Flash Attention 2
+## Flash Attention
 
-Uses efficient kernels to compute attention.
+Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
+based on your installed packages and GPU.
 
 ```yaml
 flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true
 
 For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
 
-### Nvidia
+### Flash Attention 2
 
-Requirements: Ampere, Ada, or Hopper GPUs
-
-Note: For Turing GPUs or lower, please use other attention methods.
+Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
 
 ```bash
 pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
 
 ::: {.callout-tip}
 
-If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
+If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
+Alternatively, try reinstalling flash-attn or downgrading its version.
 
 :::
 
-#### Flash Attention 3
+### Flash Attention 3
 
 Requirements: Hopper only and CUDA 12.8 (recommended)
 
@@ -50,6 +50,44 @@ cd flash-attention/hopper
 python setup.py install
 ```
+
+### Flash Attention 4
+
+Requirements: Hopper or Blackwell GPUs
+
+```bash
+pip install flash-attn-4
+```
+
+Or from source:
+
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/flash_attn/cute
+pip install -e .
+
+# FA2's flash_attn package includes a cute/ stub that shadows FA4.
+# Remove it so Python can find the real FA4 module:
+rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
+```
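To confirm that no stale `cute/` stub remains, a quick check can locate it on disk without importing it. This is a hypothetical helper, not part of Axolotl or flash-attention; it only assumes the package layout described above:

```python
# Check whether a leftover cute/ directory would shadow the FA4 package.
# Hypothetical helper: works for any installed package with that layout.
import importlib.util
import pathlib


def cute_stub_path(package="flash_attn"):
    """Return the path of package/cute if present on disk, else None."""
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None  # package not installed
    stub = pathlib.Path(list(spec.submodule_search_locations)[0]) / "cute"
    return stub if stub.exists() else None
```

If this returns a path after installing flash-attn-4 from PyPI, the `rm -r` step above still needs to be run.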
+
+::: {.callout-note}
+
+**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
+for training on Hopper, install from source using the instructions above.
+
+:::
+
+::: {.callout-warning}
+
+FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
+also supported, but only on Blackwell. Axolotl automatically detects incompatible head dimensions
+and falls back to FA2/3.
+
+:::
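The head-dimension limit can be checked by hand from a model config. A minimal sketch with hypothetical helpers, mirroring the standard-model fallback (`hidden_size / num_attention_heads`) and the limits stated above; the single-dimension check is a simplification of the `(192, 128)` Q/V pair:

```python
def head_dim_from_config(hidden_size, num_attention_heads, head_dim=None):
    """Return the attention head dimension, preferring an explicit head_dim."""
    if head_dim is not None:
        return head_dim
    return hidden_size // num_attention_heads


def fa4_supports(head_dim, blackwell=False):
    """Limits stated above: d <= 128, plus 192 (DeepSeek Q dim) on Blackwell only."""
    return head_dim <= 128 or (blackwell and head_dim == 192)


# Llama-3-8B-style config: 4096 hidden size, 32 heads -> head_dim 128, FA4-compatible
print(head_dim_from_config(4096, 32))  # 128
```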
+
+For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
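Taken together, the requirements in this section map NVIDIA compute-capability majors to usable FA versions. A rough sketch; the mapping is an assumption distilled from the requirements above, not an Axolotl API:

```python
# Compute-capability majors (assumption): 8 = Ampere/Ada, 9 = Hopper, 10/11 = Blackwell
def fa_versions_supported(sm_major):
    """FA major versions this doc lists as usable for a given SM major."""
    versions = []
    if sm_major >= 8:            # FA2: Ampere or newer (Turing and lower unsupported)
        versions.append(2)
    if sm_major == 9:            # FA3: Hopper only
        versions.append(3)
    if sm_major in (9, 10, 11):  # FA4: Hopper or Blackwell
        versions.append(4)
    return versions
```

On a Hopper card this yields all three versions, and Axolotl picks the best one installed.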
 
 ### AMD
 
 Requirements: ROCm 6.0 and above.
@@ -99,6 +99,7 @@ class PatchManager:
         self._apply_flash_attention_patches()
         self._apply_chunked_cross_entropy_patch()
         self._apply_sageattn_patches()
+        self._apply_flash_attn_4_patches()
         self._apply_fsdp_patches()
         self._apply_adapter_patches()
         self._apply_model_specific_patches()
@@ -228,6 +229,15 @@ class PatchManager:
 
         patch_sageattn()
 
+    def _apply_flash_attn_4_patches(self):
+        """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
+        if not self.cfg.flash_attention:
+            return
+
+        from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
+
+        patch_flash_attn_4(self.model_config)
+
     def _apply_model_specific_patches(self):
         """Apply patches specific to model architectures."""
         if (
src/axolotl/monkeypatch/attention/flash_attn_4.py (new file, 104 lines)
```python
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""

import torch

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _get_head_dims(model_config):
    """Extract (head_dim, head_dim_v) from a model config.

    Handles composite models (e.g. Qwen3.5 VL) via text_config and
    MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
    """
    cfg = model_config
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    # MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
    if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
        head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
        head_dim_v = getattr(cfg, "v_head_dim", head_dim)
        return head_dim, head_dim_v

    # Standard models
    if hasattr(cfg, "head_dim"):
        return cfg.head_dim, cfg.head_dim
    if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
        head_dim = cfg.hidden_size // cfg.num_attention_heads
        return head_dim, head_dim

    return None, None


def patch_flash_attn_4(model_config=None):
    """Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
    if not torch.cuda.is_available():
        return

    major, _ = torch.cuda.get_device_capability()
    # Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
    if major not in (9, 10, 11):
        return

    try:
        from flash_attn.cute import (  # noqa: F401
            flash_attn_func,
            flash_attn_varlen_func,
        )
    except ImportError:
        LOG.info(
            "Flash Attention 4 is available for your GPU and offers faster training speeds. "
            "To enable: pip install flash-attn-4"
        )
        return

    # Validate head dimensions against FA4's own constraints
    head_dim = None
    if model_config is not None:
        head_dim, head_dim_v = _get_head_dims(model_config)
        if head_dim is not None:
            try:
                from flash_attn.cute.interface import _validate_head_dims
            except ImportError:
                LOG.warning(
                    "Could not import _validate_head_dims from flash_attn.cute.interface, "
                    "unable to verify head dimension compatibility, falling back to FA2"
                )
                return

            # alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
            alignment = 8
            try:
                _validate_head_dims(head_dim, head_dim_v, major, alignment)
            except AssertionError as exc:
                LOG.warning(
                    "Model head dimensions not supported by FA4, falling back to FA2: %s",
                    exc,
                )
                return

    import transformers.modeling_flash_attention_utils as fa_utils

    if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
        return

    def _patched_lazy_imports(
        implementation, attention_wrapper=None, allow_all_kernels=False
    ):
        return (
            flash_attn_func,
            flash_attn_varlen_func,
            fa_utils._pad_input,
            fa_utils._unpad_input,
        )

    _patched_lazy_imports._axolotl_patched = True
    fa_utils._lazy_imports = _patched_lazy_imports
    LOG.info(
        "Flash Attention 4 enabled (head_dim=%s)",
        head_dim if model_config else "unknown",
    )
```
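The redirection and the `_axolotl_patched` idempotency guard above can be exercised against a stub module. A minimal sketch using a toy namespace in place of `transformers.modeling_flash_attention_utils` and placeholder strings in place of the real FA4 callables; `apply_patch` is a hypothetical stand-in for `patch_flash_attn_4` with the hardware and head-dim checks stripped out:

```python
import types

# Toy stand-in for transformers.modeling_flash_attention_utils (not the real module)
fa_utils = types.SimpleNamespace(
    _lazy_imports=lambda implementation, attention_wrapper=None, allow_all_kernels=False: "fa2",
    _pad_input="pad",
    _unpad_input="unpad",
)


def apply_patch(flash_attn_func="fa4_func", flash_attn_varlen_func="fa4_varlen"):
    """Redirect _lazy_imports to FA4 callables, guarding against double-patching."""
    if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
        return  # already patched; keep the existing redirect

    def _patched(implementation, attention_wrapper=None, allow_all_kernels=False):
        return (
            flash_attn_func,
            flash_attn_varlen_func,
            fa_utils._pad_input,
            fa_utils._unpad_input,
        )

    _patched._axolotl_patched = True
    fa_utils._lazy_imports = _patched


apply_patch()
first = fa_utils._lazy_imports
apply_patch()  # no-op: the guard sees _axolotl_patched on the current function
assert fa_utils._lazy_imports is first
```

The guard matters because `PatchManager` may run more than once per process; without it, each run would wrap the previous wrapper.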