refactor attention handling

This commit is contained in:
Wing Lian
2026-04-01 16:57:34 +00:00
committed by Wing Lian
parent c4f986874d
commit aee8c75d64
10 changed files with 476 additions and 29 deletions

View File

@@ -634,6 +634,23 @@ class ModelLoader:
def _set_attention_config(self): def _set_attention_config(self):
"""Sample packing uses custom FA2 patch""" """Sample packing uses custom FA2 patch"""
# Map attn_implementation enum values to HF attn_implementation strings.
# xformers/sage are registered in ALL_ATTENTION_FUNCTIONS and
# ALL_MASK_ATTENTION_FUNCTIONS under their own names with FA2 mask
# behavior, so they no longer need to masquerade as flash_attention_2.
# s2 still uses flash_attention_2 because it modifies FA2 internals.
# Hub kernel strings (e.g. "kernels-community/flash-attn3") fall
# through the .get() and are passed to HF unchanged.
_ATTN_IMPL_TO_HF = {
"eager": "eager",
"flash": "flash_attention_2",
"sdpa": "sdpa",
"xformers": "xformers",
"flex": "flex_attention",
"sage": "sage",
"s2": "flash_attention_2",
"fp8": "sdpa",
}
if self.cfg.gemma4_hybrid_attn_impl: if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers; # Load model with flash_attention_2 for sliding window layers;
# global layers will be patched to sdpa post-load. # global layers will be patched to sdpa post-load.
@@ -642,11 +659,14 @@ class ModelLoader:
# Set flash_attention so multipack/sample_packing patches activate # Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True self.cfg.flash_attention = True
elif self.cfg.attn_implementation: elif self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation hf_impl = _ATTN_IMPL_TO_HF.get(
self.cfg.attn_implementation, self.cfg.attn_implementation
)
self.model_kwargs["attn_implementation"] = hf_impl
self.model_config._attn_implementation = hf_impl
elif self.cfg.flex_attention: elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention" self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = "flex_attention" self.model_config._attn_implementation = "flex_attention"
elif self.cfg.flash_attention: elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass

View File

@@ -172,6 +172,7 @@ class PatchManager:
self._apply_llama_flash_attn_patches(model) self._apply_llama_flash_attn_patches(model)
self._apply_lora_kernel_patch(model) self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model) self._apply_scaling_softmax_patch(model)
self._apply_fp8_attention_patches(model)
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel): def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers. """Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
@@ -252,11 +253,29 @@ class PatchManager:
def _apply_flash_attention_patches(self): def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention.""" """Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing: if self.cfg.xformers_attention:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 from axolotl.monkeypatch.attention import register_xformers_attn
patch_xformers_attn_over_fa2() register_xformers_attn()
self.cfg.flash_attention = True
if self.cfg.sample_packing:
# Also patch FA2 slot for legacy code paths that use it directly
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention import register_sage_attn
register_sage_attn()
def _apply_fp8_attention_patches(self, model):
"""Apply FP8 low-precision attention via torchao."""
if self.cfg.attn_implementation == "fp8":
from axolotl.monkeypatch.attention.fp8_attn import patch_fp8_attention
patch_fp8_attention(model)
def _apply_chunked_cross_entropy_patch(self): def _apply_chunked_cross_entropy_patch(self):
if self.cfg.chunked_cross_entropy: if self.cfg.chunked_cross_entropy:

View File

@@ -17,3 +17,29 @@ def unpatch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward() ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()
def register_xformers_attn():
    """Make ``"xformers"`` a named attention backend in transformers.

    ``xformers_attention_forward`` is registered in ``ALL_ATTENTION_FUNCTIONS``
    and the mask function currently stored under ``"flash_attention_2"`` is
    registered again under ``"xformers"`` in ``ALL_MASK_ATTENTION_FUNCTIONS``.
    All imports are local to the call.
    """
    from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    from .xformers import xformers_attention_forward

    backend_name = "xformers"
    ALL_ATTENTION_FUNCTIONS.register(backend_name, xformers_attention_forward)

    # Reuse the flash_attention_2 mask entry so both backends share one mask function.
    fa2_mask_fn = ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
    ALL_MASK_ATTENTION_FUNCTIONS.register(backend_name, fa2_mask_fn)
def register_sage_attn():
    """Make ``"sage"`` a named attention backend in transformers.

    ``sage_attention_forward`` is registered in ``ALL_ATTENTION_FUNCTIONS``
    and the mask function currently stored under ``"flash_attention_2"`` is
    registered again under ``"sage"`` in ``ALL_MASK_ATTENTION_FUNCTIONS``.
    All imports are local to the call.
    """
    from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    from .sage_attn import sage_attention_forward

    backend_name = "sage"
    ALL_ATTENTION_FUNCTIONS.register(backend_name, sage_attention_forward)

    # Reuse the flash_attention_2 mask entry so both backends share one mask function.
    fa2_mask_fn = ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
    ALL_MASK_ATTENTION_FUNCTIONS.register(backend_name, fa2_mask_fn)

View File

@@ -0,0 +1,30 @@
"""FP8 low-precision attention via torchao.
Requires:
- PyTorch >= 2.11.0
- SM90+ (Hopper/Blackwell) GPU
- flash-attn package with FA3 support
- torchao >= 0.17.0
Uses per-head FP8 quantized attention with automatic RoPE fusion under torch.compile.
The torchao patch replaces F.scaled_dot_product_attention, so the model must use
HF's "sdpa" attention implementation for the patch to intercept attention calls.
"""
import logging
import torch
LOG = logging.getLogger(__name__)
def patch_fp8_attention(model: torch.nn.Module) -> torch.nn.Module:
    """Run torchao's low-precision attention transform on ``model``.

    Call this once the model is loaded and before ``torch.compile``. KV
    caching should be turned off (``config.use_cache = False``).

    Args:
        model: The loaded model to transform.

    Returns:
        Whatever ``apply_low_precision_attention`` returns for ``model``.
    """
    from torchao.prototype.attention import apply_low_precision_attention

    LOG.info("Applying FP8 low-precision attention (torchao)")
    patched_model = apply_low_precision_attention(model)
    return patched_model

View File

@@ -191,21 +191,9 @@ def sage_attention_forward(
def patch_sageattn(): def patch_sageattn():
"""Patch SageAttention for use with transformers.""" """Validate SageAttention is available. Registration in the attention/mask
function registries is handled by register_sage_attn() in __init__.py."""
_check_sageattn_imported() _check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS LOG.info("SageAttention validated successfully")
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

View File

@@ -955,7 +955,10 @@ def colab_inference_post_train_callback(trainer: Trainer):
""" """
handle T4 gpu, we need to convert attention to eager for inference handle T4 gpu, we need to convert attention to eager for inference
""" """
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention: if "Tesla T4" in self.gpu_name and (
self.cfg.xformers_attention
or self.cfg.attn_implementation == "xformers"
):
trainer.model.config._attn_implementation = "eager" trainer.model.config._attn_implementation = "eager"
trainer.model.gradient_checkpointing_disable() trainer.model.gradient_checkpointing_disable()
trainer.model.config.use_cache = True trainer.model.config.use_cache = True

View File

@@ -27,7 +27,12 @@ from axolotl.utils.schemas.datasets import (
) )
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.enums import (
AttnImplementation,
ChatTemplate,
RingAttnFunc,
RLType,
)
from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import ( from axolotl.utils.schemas.integrations import (
CometConfig, CometConfig,
@@ -786,10 +791,10 @@ class AxolotlInputConfig(
eager_attention: bool | None = None eager_attention: bool | None = None
attn_implementation: str | None = Field( attn_implementation: AttnImplementation | str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Specify a custom attention implementation, used mostly for kernels." "description": "Attention backend: eager, flash, sdpa, xformers, flex, sage, s2, fp8, or a custom string for kernels."
}, },
) )
@@ -1347,6 +1352,81 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def normalize_attn_implementation(cls, data):
    """Reconcile ``attn_implementation`` with the legacy boolean attention flags.

    Three situations are handled:

    * ``attn_implementation`` names a known backend and no legacy flag is
      set: every legacy flag that backend implies is switched on, so code
      that still reads the booleans keeps working.
    * ``attn_implementation`` names a known backend and legacy flags are set:
      each truthy legacy flag must be one the backend implies, otherwise a
      ``ValueError`` is raised. Any implied flag that is still falsy is then
      switched on too, so the outcome matches the first situation (for
      example ``attn_implementation: sage`` with only ``flash_attention``
      set also gets ``sage_attention``).
    * Only legacy flags are set: ``attn_implementation`` is derived from the
      highest-priority truthy flag and a deprecation warning is logged.

    An ``attn_implementation`` that is not a known backend name (such as a
    hub kernel string) is passed through unchanged and sets no flags.

    Args:
        data: Raw config mapping read with ``.get``; it is mutated in place.

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: If a truthy legacy flag is not implied by
            ``attn_implementation``.
    """
    # Known backend -> legacy flags it implies (its own flag first, then extras).
    impl_to_flags = {
        "eager": ("eager_attention",),
        "flash": ("flash_attention",),
        "sdpa": ("sdp_attention",),
        "xformers": ("xformers_attention", "flash_attention"),
        "flex": ("flex_attention",),
        "sage": ("sage_attention", "flash_attention"),
        "s2": ("s2_attention", "flash_attention"),
        "fp8": (),  # new, no legacy flags
    }

    # Reverse mapping: legacy flag -> attn_implementation value.
    flag_to_impl = {
        "eager_attention": "eager",
        "flash_attention": "flash",
        "sdp_attention": "sdpa",
        "xformers_attention": "xformers",
        "flex_attention": "flex",
        "sage_attention": "sage",
        "s2_attention": "s2",
    }

    # Specific backends come before the generic flash/sdp/eager flags so
    # that flash_attention does not mask a more specific flag set with it.
    legacy_priority = (
        "xformers_attention",
        "s2_attention",
        "sage_attention",
        "flex_attention",
        "flash_attention",
        "sdp_attention",
        "eager_attention",
    )

    attn_impl = data.get("attn_implementation")
    set_flags = [flag for flag in flag_to_impl if data.get(flag)]

    if attn_impl:
        implied_flags = impl_to_flags.get(attn_impl)
        if implied_flags is None:
            # Not a known backend name: leave the config untouched.
            return data

        for flag in set_flags:
            if flag not in implied_flags:
                raise ValueError(
                    f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
                    "Use only attn_implementation or the legacy flag, not both."
                )

        # Switch on every implied flag that is not already truthy, whether
        # or not some legacy flags were supplied alongside the backend name.
        for flag in implied_flags:
            if not data.get(flag):
                data[flag] = True
        return data

    # Legacy flags only: derive attn_implementation and warn.
    legacy_flag = next(
        (flag for flag in legacy_priority if data.get(flag)), None
    )
    if legacy_flag is not None:
        data["attn_implementation"] = flag_to_impl[legacy_flag]
        LOG.warning(
            "`%s: true` is deprecated. Use `attn_implementation: %s` instead.",
            legacy_flag,
            flag_to_impl[legacy_flag],
        )

    return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sageattn_wo_sample_packing(cls, data): def check_sageattn_wo_sample_packing(cls, data):

View File

@@ -97,6 +97,19 @@ class CustomSupportedOptimizers(str, Enum):
flash_lion = "flash_lion" flash_lion = "flash_lion"
class AttnImplementation(str, Enum):
    """Attention backend names accepted by the ``attn_implementation`` option.

    Members also subclass ``str``, so each one compares equal to its plain
    string value.
    """

    # Lowercase member names mirror the config values.
    # pylint: disable=invalid-name
    eager = "eager"
    flash = "flash"
    sdpa = "sdpa"
    xformers = "xformers"
    flex = "flex"
    sage = "sage"
    s2 = "s2"
    fp8 = "fp8"
class RingAttnFunc(str, Enum): class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations""" """Enum class for supported `ring-flash-attn` implementations"""

View File

@@ -201,6 +201,7 @@ class AttentionValidationMixin:
def check_sample_packing_without_attention(cls, data): def check_sample_packing_without_attention(cls, data):
if ( if (
data.get("sample_packing") data.get("sample_packing")
and not data.get("attn_implementation")
and not data.get("flash_attention") and not data.get("flash_attention")
and not data.get("sdp_attention") and not data.get("sdp_attention")
and not data.get("flex_attention") and not data.get("flex_attention")
@@ -215,7 +216,9 @@ class AttentionValidationMixin:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sample_packing_with_s2attn(cls, data): def check_sample_packing_with_s2attn(cls, data):
if data.get("sample_packing") and data.get("s2_attention"): if data.get("sample_packing") and (
data.get("s2_attention") or data.get("attn_implementation") == "s2"
):
raise ValueError( raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \ "Received `sample_packing=true` and `s2_attention=true`; however, \
shifted-sparse attention does not currently support sample packing." shifted-sparse attention does not currently support sample packing."
@@ -225,10 +228,12 @@ class AttentionValidationMixin:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_scaling_softmax_requires_flex(cls, data): def check_scaling_softmax_requires_flex(cls, data):
if data.get("scaling_softmax") and not data.get("flex_attention"): if data.get("scaling_softmax") and not (
data.get("flex_attention") or data.get("attn_implementation") == "flex"
):
raise ValueError( raise ValueError(
"scaling_softmax requires flex_attention: true\n" "scaling_softmax requires flex attention.\n"
"Add 'flex_attention: true' to your config file.\n" "Add 'attn_implementation: flex' to your config file.\n"
) )
return data return data

View File

@@ -0,0 +1,263 @@
"""
Tests for attn_implementation normalization, registry registration, and
backwards compatibility with legacy boolean attention flags.
"""
import pytest
from axolotl.utils.schemas.config import AxolotlInputConfig
class TestAttnImplementationNormalizer:
    """Checks for ``AxolotlInputConfig.normalize_attn_implementation`` on raw dicts."""

    @staticmethod
    def _normalize(data):
        """Run the validator directly on ``data`` and return its result."""
        return AxolotlInputConfig.normalize_attn_implementation(data)

    # --- Forward mapping: attn_implementation -> legacy flags ---

    @pytest.mark.parametrize(
        "impl,expected_flags",
        [
            ("eager", {"eager_attention": True}),
            ("flash", {"flash_attention": True}),
            ("sdpa", {"sdp_attention": True}),
            ("flex", {"flex_attention": True}),
            ("xformers", {"xformers_attention": True, "flash_attention": True}),
            ("sage", {"sage_attention": True, "flash_attention": True}),
            ("s2", {"s2_attention": True, "flash_attention": True}),
        ],
    )
    def test_attn_impl_sets_legacy_flags(self, impl, expected_flags):
        normalized = self._normalize({"attn_implementation": impl})
        for flag, val in expected_flags.items():
            assert normalized.get(flag) == val, f"{impl}: expected {flag}={val}"

    def test_fp8_sets_no_legacy_flags(self):
        normalized = self._normalize({"attn_implementation": "fp8"})
        legacy_flags = (
            "flash_attention",
            "sdp_attention",
            "eager_attention",
            "xformers_attention",
            "sage_attention",
            "flex_attention",
            "s2_attention",
        )
        for flag in legacy_flags:
            assert not normalized.get(flag), f"fp8 should not set {flag}"

    # --- Reverse mapping: legacy flags -> attn_implementation ---

    @pytest.mark.parametrize(
        "flag,expected_impl",
        [
            ("flash_attention", "flash"),
            ("sdp_attention", "sdpa"),
            ("xformers_attention", "xformers"),
            ("flex_attention", "flex"),
            ("sage_attention", "sage"),
            ("eager_attention", "eager"),
            ("s2_attention", "s2"),
        ],
    )
    def test_legacy_flag_sets_attn_impl(self, flag, expected_impl):
        normalized = self._normalize({flag: True})
        assert normalized["attn_implementation"] == expected_impl

    # --- Priority: s2/sage should win over flash when both set ---

    def test_s2_plus_flash_maps_to_s2(self):
        """Legacy configs often have both s2_attention and flash_attention."""
        config = {"s2_attention": True, "flash_attention": True}
        assert self._normalize(config)["attn_implementation"] == "s2"

    def test_sage_plus_flash_maps_to_sage(self):
        """sage_attention should take priority over flash_attention."""
        config = {"sage_attention": True, "flash_attention": True}
        assert self._normalize(config)["attn_implementation"] == "sage"

    # --- Consistency: both set, matching ---

    def test_consistent_both_set_no_error(self):
        config = {"attn_implementation": "flash", "flash_attention": True}
        normalized = self._normalize(config)
        assert normalized["attn_implementation"] == "flash"
        assert normalized["flash_attention"] is True

    def test_consistent_xformers_with_extra_flags(self):
        """xformers needs flash_attention=True, so both flags with attn_impl should be OK."""
        config = {
            "attn_implementation": "xformers",
            "xformers_attention": True,
            "flash_attention": True,
        }
        assert self._normalize(config)["attn_implementation"] == "xformers"

    def test_consistent_s2_with_flash(self):
        config = {
            "attn_implementation": "s2",
            "s2_attention": True,
            "flash_attention": True,
        }
        assert self._normalize(config)["attn_implementation"] == "s2"

    # --- Conflict detection ---

    def test_conflicting_impl_and_flag_raises(self):
        config = {"attn_implementation": "flash", "sdp_attention": True}
        with pytest.raises(ValueError, match="conflicts with"):
            self._normalize(config)

    def test_conflicting_xformers_impl_with_sdp_flag(self):
        config = {"attn_implementation": "xformers", "sdp_attention": True}
        with pytest.raises(ValueError, match="conflicts with"):
            self._normalize(config)

    # --- Hub kernel strings pass through ---

    def test_hub_kernel_passthrough(self):
        hub_kernel = "kernels-community/flash-attn3"
        normalized = self._normalize({"attn_implementation": hub_kernel})
        assert normalized["attn_implementation"] == "kernels-community/flash-attn3"
        # Should not set any legacy flags
        unexpected_flags = (
            "flash_attention",
            "sdp_attention",
            "eager_attention",
            "xformers_attention",
        )
        for flag in unexpected_flags:
            assert not normalized.get(flag)

    def test_custom_string_passthrough(self):
        normalized = self._normalize({"attn_implementation": "my_custom_kernel"})
        assert normalized["attn_implementation"] == "my_custom_kernel"

    # --- No attention set ---

    def test_no_attention_set_is_noop(self):
        normalized = self._normalize({"some_other_config": True})
        assert normalized.get("attn_implementation") is None

    # --- Sample packing interactions ---

    def test_xformers_with_sample_packing_sets_flash(self):
        """xformers + sample_packing needs flash_attention=True for the patch chain."""
        config = {"attn_implementation": "xformers", "sample_packing": True}
        normalized = self._normalize(config)
        assert normalized["xformers_attention"] is True
        assert normalized["flash_attention"] is True
class TestAttnImplToHFMapping:
    """Check the attn_implementation -> HF attention string table.

    The table below is a hand-maintained copy, so these tests exercise the
    copy and the ``dict.get`` fall-through pattern rather than the loader.
    """

    # This dict mirrors _ATTN_IMPL_TO_HF in model.py
    _ATTN_IMPL_TO_HF = {
        "eager": "eager",
        "flash": "flash_attention_2",
        "sdpa": "sdpa",
        "xformers": "xformers",
        "flex": "flex_attention",
        "sage": "sage",
        "s2": "flash_attention_2",
        "fp8": "sdpa",
    }

    @pytest.mark.parametrize(
        "impl,expected_hf",
        [
            ("eager", "eager"),
            ("flash", "flash_attention_2"),
            ("sdpa", "sdpa"),
            ("xformers", "xformers"),
            ("flex", "flex_attention"),
            ("sage", "sage"),
            ("s2", "flash_attention_2"),
            ("fp8", "sdpa"),
        ],
    )
    def test_known_impl_maps_correctly(self, impl, expected_hf):
        hf_name = self._ATTN_IMPL_TO_HF[impl]
        assert hf_name == expected_hf

    def test_hub_kernel_falls_through(self):
        """Hub kernel strings should pass through .get() unchanged."""
        hub_str = "kernels-community/flash-attn3"
        assert self._ATTN_IMPL_TO_HF.get(hub_str, hub_str) == hub_str
def _xformers_available():
try:
import xformers.ops # noqa: F401
return True
except (ImportError, OSError):
return False
class TestAttentionRegistration:
    """Check that the register helpers add entries to HF's attention registries."""

    @pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
    def test_register_xformers(self):
        from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        from axolotl.monkeypatch.attention import register_xformers_attn

        register_xformers_attn()

        assert "xformers" in ALL_ATTENTION_FUNCTIONS
        assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS
        # xformers mask should be the same function as flash_attention_2's mask
        xformers_mask = ALL_MASK_ATTENTION_FUNCTIONS["xformers"]
        fa2_mask = ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
        assert xformers_mask == fa2_mask

    def test_register_sage(self):
        from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        from axolotl.monkeypatch.attention import register_sage_attn

        register_sage_attn()

        assert "sage" in ALL_ATTENTION_FUNCTIONS
        assert "sage" in ALL_MASK_ATTENTION_FUNCTIONS
        sage_mask = ALL_MASK_ATTENTION_FUNCTIONS["sage"]
        fa2_mask = ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
        assert sage_mask == fa2_mask

    @pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
    def test_xformers_does_not_overwrite_fa2(self):
        """Registering xformers should not modify the flash_attention_2 slot."""
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        # Snapshot the FA2 entry before the axolotl module is imported or called.
        fa2_before = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]

        from axolotl.monkeypatch.attention import register_xformers_attn

        register_xformers_attn()

        fa2_after = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
        assert fa2_after is fa2_before

    def test_sage_does_not_overwrite_fa2(self):
        """Registering sage should not modify the flash_attention_2 slot."""
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        # Snapshot the FA2 entry before the axolotl module is imported or called.
        fa2_before = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]

        from axolotl.monkeypatch.attention import register_sage_attn

        register_sage_attn()

        fa2_after = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
        assert fa2_after is fa2_before