feat: add FA4 (#3481)

* feat: add FA4

* chore: update docs

* fix: recommend FA4 for those with compatible devices

* fix: adjust import check and add head_dim check

* chore: add limitation to doc

* fix: log warning and quit if cannot import validator

* chore: simplify

* fix: add caveat with FA2 shadow dir
This commit is contained in:
NanoCode012
2026-03-16 11:13:18 +07:00
committed by GitHub
parent 4a5876df7a
commit 7da5f94379
4 changed files with 161 additions and 9 deletions

View File

@@ -99,6 +99,7 @@ class PatchManager:
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -228,6 +229,15 @@ class PatchManager:
patch_sageattn()
def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention:
return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
patch_flash_attn_4(self.model_config)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (

View File

@@ -0,0 +1,104 @@
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""
import torch
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _get_head_dims(model_config):
"""Extract (head_dim, head_dim_v) from a model config.
Handles composite models (e.g. Qwen3.5 VL) via text_config and
MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
"""
cfg = model_config
if hasattr(cfg, "text_config"):
cfg = cfg.text_config
# MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
head_dim_v = getattr(cfg, "v_head_dim", head_dim)
return head_dim, head_dim_v
# Standard models
if hasattr(cfg, "head_dim"):
return cfg.head_dim, cfg.head_dim
if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
head_dim = cfg.hidden_size // cfg.num_attention_heads
return head_dim, head_dim
return None, None
def patch_flash_attn_4(model_config=None):
"""Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
if not torch.cuda.is_available():
return
major, _ = torch.cuda.get_device_capability()
# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
if major not in (9, 10, 11):
return
try:
from flash_attn.cute import ( # noqa: F401
flash_attn_func,
flash_attn_varlen_func,
)
except ImportError:
LOG.info(
"Flash Attention 4 is available for your GPU and offers faster training speeds. "
"To enable: pip install flash-attn-4"
)
return
# Validate head dimensions against FA4's own constraints
head_dim = None
if model_config is not None:
head_dim, head_dim_v = _get_head_dims(model_config)
if head_dim is not None:
try:
from flash_attn.cute.interface import _validate_head_dims
except ImportError:
LOG.warning(
"Could not import _validate_head_dims from flash_attn.cute.interface, "
"unable to verify head dimension compatibility, falling back to FA2"
)
return
# alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
alignment = 8
try:
_validate_head_dims(head_dim, head_dim_v, major, alignment)
except AssertionError as exc:
LOG.warning(
"Model head dimensions not supported by FA4, "
"falling back to FA2: %s",
exc,
)
return
import transformers.modeling_flash_attention_utils as fa_utils
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
return
def _patched_lazy_imports(
implementation, attention_wrapper=None, allow_all_kernels=False
):
return (
flash_attn_func,
flash_attn_varlen_func,
fa_utils._pad_input,
fa_utils._unpad_input,
)
_patched_lazy_imports._axolotl_patched = True
fa_utils._lazy_imports = _patched_lazy_imports
LOG.info(
"Flash Attention 4 enabled (head_dim=%s)",
head_dim if model_config else "unknown",
)