From 7da5f943795f1e732291229fc779a8011fce0486 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 16 Mar 2026 11:13:18 +0700
Subject: [PATCH] feat: add FA4 (#3481)

* feat: add FA4

* chore: update docs

* fix: recommend FA4 for those with compatible devices

* fix: adjust import check and add head_dim check

* chore: add limitation to doc

* fix: log warning and quit if cannot import validator

* chore: simplify

* fix: add caveat with FA2 shadow dir
---
 README.md                                     |   2 +-
 docs/attention.qmd                            |  54 +++++++--
 src/axolotl/loaders/patch_manager.py          |  10 ++
 .../monkeypatch/attention/flash_attn_4.py     | 104 ++++++++++++++++++
 4 files changed, 161 insertions(+), 9 deletions(-)
 create mode 100644 src/axolotl/monkeypatch/attention/flash_attn_4.py

diff --git a/README.md b/README.md
index 594b06156..f10e08b42 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ Features:
 - **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
 - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
 - **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
-- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
diff --git a/docs/attention.qmd b/docs/attention.qmd
index 21004277e..771299a29 100644
--- a/docs/attention.qmd
+++ b/docs/attention.qmd
@@ -13,9 +13,10 @@ sdp_attention: true
 ```
 For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
 
-## Flash Attention 2
+## Flash Attention
 
-Uses efficient kernels to compute attention.
+Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
+based on your installed packages and GPU.
 
 ```yaml
 flash_attention: true
@@ -23,11 +24,9 @@ flash_attention: true
 ```
 For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
 
-### Nvidia
+### Flash Attention 2
 
-Requirements: Ampere, Ada, or Hopper GPUs
-
-Note: For Turing GPUs or lower, please use other attention methods.
+Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
 
 ```bash
 pip install flash-attn --no-build-isolation
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
 ```
 
 ::: {.callout-tip}
 
-If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
+If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
+Alternatively, try reinstalling or downgrading a version.
 
 :::
 
-#### Flash Attention 3
+### Flash Attention 3
 
 Requirements: Hopper only and CUDA 12.8 (recommended)
 
@@ -50,6 +50,44 @@ cd flash-attention/hopper
 python setup.py install
 ```
 
+### Flash Attention 4
+
+Requirements: Hopper or Blackwell GPUs
+
+```bash
+pip install flash-attn-4
+```
+
+Or from source:
+
+```bash
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/flash_attn/cute
+
+pip install -e .
+
+# FA2's flash_attn package includes a cute/ stub that shadows FA4.
+# Remove it so Python can find the real FA4 module:
+rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
+```
+
+::: {.callout-note}
+
+**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
+for training on Hopper, install from source using the instructions above.
+
+:::
+
+::: {.callout-warning}
+
+FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
+also supported, but only on Blackwell. Axolotl automatically detects incompatible head dimensions
+and falls back to FA2/3.
+
+:::
+
+For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
+
 ### AMD
 
 Requirements: ROCm 6.0 and above.
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 857a2f76f..5874c940b 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -99,6 +99,7 @@ class PatchManager:
         self._apply_flash_attention_patches()
         self._apply_chunked_cross_entropy_patch()
         self._apply_sageattn_patches()
+        self._apply_flash_attn_4_patches()
         self._apply_fsdp_patches()
         self._apply_adapter_patches()
         self._apply_model_specific_patches()
@@ -228,6 +229,15 @@
 
         patch_sageattn()
 
+    def _apply_flash_attn_4_patches(self):
+        """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
+        if not self.cfg.flash_attention:
+            return
+
+        from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
+
+        patch_flash_attn_4(self.model_config)
+
     def _apply_model_specific_patches(self):
         """Apply patches specific to model architectures."""
         if (
diff --git a/src/axolotl/monkeypatch/attention/flash_attn_4.py b/src/axolotl/monkeypatch/attention/flash_attn_4.py
new file mode 100644
index 000000000..5ebc93670
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/flash_attn_4.py
@@ -0,0 +1,104 @@
+"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""
+
+import torch
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def _get_head_dims(model_config):
+    """Extract (head_dim, head_dim_v) from a model config.
+
+    Handles composite models (e.g. Qwen3.5 VL) via text_config and
+    MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
+    """
+    cfg = model_config
+    if hasattr(cfg, "text_config"):
+        cfg = cfg.text_config
+
+    # MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
+    if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
+        head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
+        head_dim_v = getattr(cfg, "v_head_dim", head_dim)
+        return head_dim, head_dim_v
+
+    # Standard models
+    if hasattr(cfg, "head_dim"):
+        return cfg.head_dim, cfg.head_dim
+    if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
+        head_dim = cfg.hidden_size // cfg.num_attention_heads
+        return head_dim, head_dim
+
+    return None, None
+
+
+def patch_flash_attn_4(model_config=None):
+    """Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
+    if not torch.cuda.is_available():
+        return
+
+    major, _ = torch.cuda.get_device_capability()
+    # Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
+    if major not in (9, 10, 11):
+        return
+
+    try:
+        from flash_attn.cute import (  # noqa: F401
+            flash_attn_func,
+            flash_attn_varlen_func,
+        )
+    except ImportError:
+        LOG.info(
+            "Flash Attention 4 is available for your GPU and offers faster training speeds. "
" + "To enable: pip install flash-attn-4" + ) + return + + # Validate head dimensions against FA4's own constraints + head_dim = None + if model_config is not None: + head_dim, head_dim_v = _get_head_dims(model_config) + if head_dim is not None: + try: + from flash_attn.cute.interface import _validate_head_dims + except ImportError: + LOG.warning( + "Could not import _validate_head_dims from flash_attn.cute.interface, " + "unable to verify head dimension compatibility, falling back to FA2" + ) + return + + # alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8 + alignment = 8 + try: + _validate_head_dims(head_dim, head_dim_v, major, alignment) + except AssertionError as exc: + LOG.warning( + "Model head dimensions not supported by FA4, " + "falling back to FA2: %s", + exc, + ) + return + + import transformers.modeling_flash_attention_utils as fa_utils + + if getattr(fa_utils._lazy_imports, "_axolotl_patched", False): + return + + def _patched_lazy_imports( + implementation, attention_wrapper=None, allow_all_kernels=False + ): + return ( + flash_attn_func, + flash_attn_varlen_func, + fa_utils._pad_input, + fa_utils._unpad_input, + ) + + _patched_lazy_imports._axolotl_patched = True + fa_utils._lazy_imports = _patched_lazy_imports + LOG.info( + "Flash Attention 4 enabled (head_dim=%s)", + head_dim if model_config else "unknown", + )