feat: add FA4 (#3481)

* feat: add FA4

* chore: update docs

* fix: recommend FA4 for those with compatible devices

* fix: adjust import check and add head_dim check

* chore: add limitation to doc

* fix: log warning and quit if cannot import validator

* chore: simplify

* fix: add caveat with FA2 shadow dir
This commit is contained in:
NanoCode012
2026-03-16 11:13:18 +07:00
committed by GitHub
parent 4a5876df7a
commit 7da5f94379
4 changed files with 161 additions and 9 deletions

View File

@@ -75,7 +75,7 @@ Features:
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

View File

@@ -13,9 +13,10 @@ sdp_attention: true
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
## Flash Attention
Uses efficient kernels to compute attention.
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
```yaml
flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
```bash
pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
@@ -50,6 +50,44 @@ cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
```
Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```
::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.
:::
::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.

View File

@@ -99,6 +99,7 @@ class PatchManager:
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -228,6 +229,15 @@ class PatchManager:
patch_sageattn()
def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention:
return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
patch_flash_attn_4(self.model_config)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (

View File

@@ -0,0 +1,104 @@
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""
import torch
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _get_head_dims(model_config):
"""Extract (head_dim, head_dim_v) from a model config.
Handles composite models (e.g. Qwen3.5 VL) via text_config and
MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
"""
cfg = model_config
if hasattr(cfg, "text_config"):
cfg = cfg.text_config
# MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
head_dim_v = getattr(cfg, "v_head_dim", head_dim)
return head_dim, head_dim_v
# Standard models
if hasattr(cfg, "head_dim"):
return cfg.head_dim, cfg.head_dim
if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
head_dim = cfg.hidden_size // cfg.num_attention_heads
return head_dim, head_dim
return None, None
def patch_flash_attn_4(model_config=None):
"""Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
if major not in (9, 10, 11):
return
try:
from flash_attn.cute import ( # noqa: F401
flash_attn_func,
flash_attn_varlen_func,
)
except ImportError:
LOG.info(
"Flash Attention 4 is available for your GPU and offers faster training speeds. "
"To enable: pip install flash-attn-4"
)
return
# Validate head dimensions against FA4's own constraints
head_dim = None
if model_config is not None:
head_dim, head_dim_v = _get_head_dims(model_config)
if head_dim is not None:
try:
from flash_attn.cute.interface import _validate_head_dims
except ImportError:
LOG.warning(
"Could not import _validate_head_dims from flash_attn.cute.interface, "
"unable to verify head dimension compatibility, falling back to FA2"
)
return
# alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
alignment = 8
try:
_validate_head_dims(head_dim, head_dim_v, major, alignment)
except AssertionError as exc:
LOG.warning(
"Model head dimensions not supported by FA4, "
"falling back to FA2: %s",
exc,
)
return
import transformers.modeling_flash_attention_utils as fa_utils
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
fa_utils._pad_input,
fa_utils._unpad_input,
)
_patched_lazy_imports._axolotl_patched = True
fa_utils._lazy_imports = _patched_lazy_imports
LOG.info(
"Flash Attention 4 enabled (head_dim=%s)",
head_dim if model_config else "unknown",
)