feat: add FA4 (#3481)
* feat: add FA4 * chore: update docs * fix: recommend FA4 for those with compatible devices * fix: adjust import check and add head_dim check * chore: add limitation to doc * fix: log warning and quit if cannot import validator * chore: simplify * fix: add caveat with FA2 shadow dir
This commit is contained in:
@@ -75,7 +75,7 @@ Features:
|
|||||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
|
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
|
||||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||||
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
|
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
|
||||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,10 @@ sdp_attention: true
|
|||||||
|
|
||||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||||
|
|
||||||
## Flash Attention 2
|
## Flash Attention
|
||||||
|
|
||||||
Uses efficient kernels to compute attention.
|
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||||
|
based on your installed packages and GPU.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
@@ -23,11 +24,9 @@ flash_attention: true
|
|||||||
|
|
||||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||||
|
|
||||||
### Nvidia
|
### Flash Attention 2
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs
|
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||||
|
|
||||||
Note: For Turing GPUs or lower, please use other attention methods.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install flash-attn --no-build-isolation
|
pip install flash-attn --no-build-isolation
|
||||||
@@ -35,11 +34,12 @@ pip install flash-attn --no-build-isolation
|
|||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|
||||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
|
||||||
|
Alternatively, try reinstall or downgrade a version.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
#### Flash Attention 3
|
### Flash Attention 3
|
||||||
|
|
||||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||||
|
|
||||||
@@ -50,6 +50,44 @@ cd flash-attention/hopper
|
|||||||
python setup.py install
|
python setup.py install
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Flash Attention 4
|
||||||
|
|
||||||
|
Requirements: Hopper or Blackwell GPUs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install flash-attn-4
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from source:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||||
|
cd flash-attention/flash_attn/cute
|
||||||
|
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
|
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||||
|
# Remove it so Python can find the real FA4 module:
|
||||||
|
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
|
||||||
|
for training on Hopper, install from source using the instructions above.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
|
||||||
|
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
|
||||||
|
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
|
||||||
|
and falls back to FA2/3.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||||
|
|
||||||
### AMD
|
### AMD
|
||||||
|
|
||||||
Requirements: ROCm 6.0 and above.
|
Requirements: ROCm 6.0 and above.
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class PatchManager:
|
|||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
self._apply_chunked_cross_entropy_patch()
|
self._apply_chunked_cross_entropy_patch()
|
||||||
self._apply_sageattn_patches()
|
self._apply_sageattn_patches()
|
||||||
|
self._apply_flash_attn_4_patches()
|
||||||
self._apply_fsdp_patches()
|
self._apply_fsdp_patches()
|
||||||
self._apply_adapter_patches()
|
self._apply_adapter_patches()
|
||||||
self._apply_model_specific_patches()
|
self._apply_model_specific_patches()
|
||||||
@@ -228,6 +229,15 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_sageattn()
|
patch_sageattn()
|
||||||
|
|
||||||
|
def _apply_flash_attn_4_patches(self):
|
||||||
|
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
|
||||||
|
if not self.cfg.flash_attention:
|
||||||
|
return
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
|
||||||
|
|
||||||
|
patch_flash_attn_4(self.model_config)
|
||||||
|
|
||||||
def _apply_model_specific_patches(self):
|
def _apply_model_specific_patches(self):
|
||||||
"""Apply patches specific to model architectures."""
|
"""Apply patches specific to model architectures."""
|
||||||
if (
|
if (
|
||||||
|
|||||||
104
src/axolotl/monkeypatch/attention/flash_attn_4.py
Normal file
104
src/axolotl/monkeypatch/attention/flash_attn_4.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Transparently upgrade FA2 to FA4 when available on SM90+ hardware."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_head_dims(model_config):
|
||||||
|
"""Extract (head_dim, head_dim_v) from a model config.
|
||||||
|
|
||||||
|
Handles composite models (e.g. Qwen3.5 VL) via text_config and
|
||||||
|
MLA models (DeepSeek/Kimi) that have separate Q/V head dimensions.
|
||||||
|
"""
|
||||||
|
cfg = model_config
|
||||||
|
if hasattr(cfg, "text_config"):
|
||||||
|
cfg = cfg.text_config
|
||||||
|
|
||||||
|
# MLA models: Q head_dim = qk_nope + qk_rope, V head_dim = v_head_dim
|
||||||
|
if hasattr(cfg, "qk_nope_head_dim") and hasattr(cfg, "qk_rope_head_dim"):
|
||||||
|
head_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim
|
||||||
|
head_dim_v = getattr(cfg, "v_head_dim", head_dim)
|
||||||
|
return head_dim, head_dim_v
|
||||||
|
|
||||||
|
# Standard models
|
||||||
|
if hasattr(cfg, "head_dim"):
|
||||||
|
return cfg.head_dim, cfg.head_dim
|
||||||
|
if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
|
||||||
|
head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||||
|
return head_dim, head_dim
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_flash_attn_4(model_config=None):
|
||||||
|
"""Patch _lazy_imports to redirect FA2 imports to FA4 if available on supported hardware."""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
|
||||||
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
# Matches flash_attn/cute/interface.py: arch / 10 in [9, 10, 11]
|
||||||
|
if major not in (9, 10, 11):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn.cute import ( # noqa: F401
|
||||||
|
flash_attn_func,
|
||||||
|
flash_attn_varlen_func,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"Flash Attention 4 is available for your GPU and offers faster training speeds. "
|
||||||
|
"To enable: pip install flash-attn-4"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate head dimensions against FA4's own constraints
|
||||||
|
head_dim = None
|
||||||
|
if model_config is not None:
|
||||||
|
head_dim, head_dim_v = _get_head_dims(model_config)
|
||||||
|
if head_dim is not None:
|
||||||
|
try:
|
||||||
|
from flash_attn.cute.interface import _validate_head_dims
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning(
|
||||||
|
"Could not import _validate_head_dims from flash_attn.cute.interface, "
|
||||||
|
"unable to verify head dimension compatibility, falling back to FA2"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# alignment = 16 // element_size; bf16/fp16 = 2 bytes -> alignment = 8
|
||||||
|
alignment = 8
|
||||||
|
try:
|
||||||
|
_validate_head_dims(head_dim, head_dim_v, major, alignment)
|
||||||
|
except AssertionError as exc:
|
||||||
|
LOG.warning(
|
||||||
|
"Model head dimensions not supported by FA4, "
|
||||||
|
"falling back to FA2: %s",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
import transformers.modeling_flash_attention_utils as fa_utils
|
||||||
|
|
||||||
|
if getattr(fa_utils._lazy_imports, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _patched_lazy_imports(
|
||||||
|
implementation, attention_wrapper=None, allow_all_kernels=False
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
flash_attn_func,
|
||||||
|
flash_attn_varlen_func,
|
||||||
|
fa_utils._pad_input,
|
||||||
|
fa_utils._unpad_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
_patched_lazy_imports._axolotl_patched = True
|
||||||
|
fa_utils._lazy_imports = _patched_lazy_imports
|
||||||
|
LOG.info(
|
||||||
|
"Flash Attention 4 enabled (head_dim=%s)",
|
||||||
|
head_dim if model_config else "unknown",
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user