diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 4f5779327..c2dbf00aa 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -634,6 +634,23 @@ class ModelLoader: def _set_attention_config(self): """Sample packing uses custom FA2 patch""" + # Map attn_implementation enum values to HF attn_implementation strings. + # xformers/sage are registered in ALL_ATTENTION_FUNCTIONS and + # ALL_MASK_ATTENTION_FUNCTIONS under their own names with FA2 mask + # behavior, so they no longer need to masquerade as flash_attention_2. + # s2 still uses flash_attention_2 because it modifies FA2 internals. + # Hub kernel strings (e.g. "kernels-community/flash-attn3") fall + # through the .get() and are passed to HF unchanged. + _ATTN_IMPL_TO_HF = { + "eager": "eager", + "flash": "flash_attention_2", + "sdpa": "sdpa", + "xformers": "xformers", + "flex": "flex_attention", + "sage": "sage", + "s2": "flash_attention_2", + "fp8": "sdpa", + } if self.cfg.gemma4_hybrid_attn_impl: # Load model with flash_attention_2 for sliding window layers; # global layers will be patched to sdpa post-load. @@ -642,11 +659,14 @@ class ModelLoader: # Set flash_attention so multipack/sample_packing patches activate self.cfg.flash_attention = True elif self.cfg.attn_implementation: - self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation + hf_impl = _ATTN_IMPL_TO_HF.get( + self.cfg.attn_implementation, self.cfg.attn_implementation + ) + self.model_kwargs["attn_implementation"] = hf_impl + self.model_config._attn_implementation = hf_impl elif self.cfg.flex_attention: self.model_kwargs["attn_implementation"] = "flex_attention" self.model_config._attn_implementation = "flex_attention" - elif self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 01d9997d7..ebe0e6474 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -172,6 +172,7 @@ class PatchManager: self._apply_llama_flash_attn_patches(model) self._apply_lora_kernel_patch(model) self._apply_scaling_softmax_patch(model) + self._apply_fp8_attention_patches(model) def _apply_gemma_hybrid_attention(self, model: PreTrainedModel): """Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers. 
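The `_ATTN_IMPL_TO_HF` lookup in the `_set_attention_config` hunk above is the single point where axolotl's `attn_implementation` values are translated into the strings transformers expects; anything not in the mapping (hub kernel repos) passes through unchanged. A minimal standalone sketch of that resolution step, for reference only; the free-standing dict and helper name below are illustrative and not part of this diff:

    # Mirrors the _ATTN_IMPL_TO_HF lookup added above; hub kernel strings
    # (anything not in the mapping) fall through .get() unchanged.
    _ATTN_IMPL_TO_HF = {
        "eager": "eager",
        "flash": "flash_attention_2",
        "sdpa": "sdpa",
        "xformers": "xformers",
        "flex": "flex_attention",
        "sage": "sage",
        "s2": "flash_attention_2",
        "fp8": "sdpa",
    }

    def resolve_attn_implementation(cfg_value: str) -> str:
        """Map an axolotl attn_implementation value to the HF attn_implementation string."""
        return _ATTN_IMPL_TO_HF.get(cfg_value, cfg_value)

    assert resolve_attn_implementation("s2") == "flash_attention_2"   # s2 still rides on FA2 internals
    assert resolve_attn_implementation("fp8") == "sdpa"               # fp8 patch intercepts SDPA calls
    assert resolve_attn_implementation("kernels-community/flash-attn3") == "kernels-community/flash-attn3"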
@@ -252,11 +253,29 @@ class PatchManager: def _apply_flash_attention_patches(self): """Apply patches related to Flash Attention.""" - if self.cfg.xformers_attention and self.cfg.sample_packing: - from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + if self.cfg.xformers_attention: + from axolotl.monkeypatch.attention import register_xformers_attn - patch_xformers_attn_over_fa2() - self.cfg.flash_attention = True + register_xformers_attn() + + if self.cfg.sample_packing: + # Also patch FA2 slot for legacy code paths that use it directly + from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + + patch_xformers_attn_over_fa2() + self.cfg.flash_attention = True + + if self.cfg.sage_attention: + from axolotl.monkeypatch.attention import register_sage_attn + + register_sage_attn() + + def _apply_fp8_attention_patches(self, model): + """Apply FP8 low-precision attention via torchao.""" + if self.cfg.attn_implementation == "fp8": + from axolotl.monkeypatch.attention.fp8_attn import patch_fp8_attention + + patch_fp8_attention(model) def _apply_chunked_cross_entropy_patch(self): if self.cfg.chunked_cross_entropy: diff --git a/src/axolotl/monkeypatch/attention/__init__.py b/src/axolotl/monkeypatch/attention/__init__.py index 15ed764f4..74bd61e77 100644 --- a/src/axolotl/monkeypatch/attention/__init__.py +++ b/src/axolotl/monkeypatch/attention/__init__.py @@ -17,3 +17,29 @@ def unpatch_xformers_attn_over_fa2(): from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward() + + +def register_xformers_attn(): + """Register xformers as its own attention backend with FA2 mask behavior.""" + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from .xformers import xformers_attention_forward + + ALL_ATTENTION_FUNCTIONS.register("xformers", xformers_attention_forward) + ALL_MASK_ATTENTION_FUNCTIONS.register( + "xformers", ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) + + +def register_sage_attn(): + """Register sage as its own attention backend with FA2 mask behavior.""" + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from .sage_attn import sage_attention_forward + + ALL_ATTENTION_FUNCTIONS.register("sage", sage_attention_forward) + ALL_MASK_ATTENTION_FUNCTIONS.register( + "sage", ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) diff --git a/src/axolotl/monkeypatch/attention/fp8_attn.py b/src/axolotl/monkeypatch/attention/fp8_attn.py new file mode 100644 index 000000000..224e8c3b7 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/fp8_attn.py @@ -0,0 +1,30 @@ +"""FP8 low-precision attention via torchao. + +Requires: + - PyTorch >= 2.11.0 + - SM90+ (Hopper/Blackwell) GPU + - flash-attn package with FA3 support + - torchao >= 0.17.0 + +Uses per-head FP8 quantized attention with automatic RoPE fusion under torch.compile. +The torchao patch replaces F.scaled_dot_product_attention, so the model must use +HF's "sdpa" attention implementation for the patch to intercept attention calls. +""" + +import logging + +import torch + +LOG = logging.getLogger(__name__) + + +def patch_fp8_attention(model: torch.nn.Module) -> torch.nn.Module: + """Apply FP8 low-precision attention to a model. + + Must be called after model loading and before torch.compile. + KV caching should be disabled (config.use_cache = False). 
+ """ + from torchao.prototype.attention import apply_low_precision_attention + + LOG.info("Applying FP8 low-precision attention (torchao)") + return apply_low_precision_attention(model) diff --git a/src/axolotl/monkeypatch/attention/sage_attn.py b/src/axolotl/monkeypatch/attention/sage_attn.py index cc9fdb94d..6e9ba0f85 100644 --- a/src/axolotl/monkeypatch/attention/sage_attn.py +++ b/src/axolotl/monkeypatch/attention/sage_attn.py @@ -191,21 +191,9 @@ def sage_attention_forward( def patch_sageattn(): - """Patch SageAttention for use with transformers.""" + """Validate SageAttention is available. Registration in the attention/mask + function registries is handled by register_sage_attn() in __init__.py.""" _check_sageattn_imported() - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - - # Replace flash attention with sage attention - ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward) - - # Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS - # Register sage_attention with the global attention interface - # ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward) - - # from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask - - # ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask) - - LOG.info("SageAttention patched successfully") + LOG.info("SageAttention validated successfully") diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 8137bac0c..5635e1261 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -955,7 +955,10 @@ def colab_inference_post_train_callback(trainer: Trainer): """ handle T4 gpu, we need to convert attention to eager for inference """ - if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention: + if "Tesla T4" in self.gpu_name and ( + self.cfg.xformers_attention + or self.cfg.attn_implementation == "xformers" + ): trainer.model.config._attn_implementation = "eager" trainer.model.gradient_checkpointing_disable() trainer.model.config.use_cache = True diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 6ee672c8c..6c579efa5 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -27,7 +27,12 @@ from axolotl.utils.schemas.datasets import ( ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig -from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType +from axolotl.utils.schemas.enums import ( + AttnImplementation, + ChatTemplate, + RingAttnFunc, + RLType, +) from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( CometConfig, @@ -786,10 +791,10 @@ class AxolotlInputConfig( eager_attention: bool | None = None - attn_implementation: str | None = Field( + attn_implementation: AttnImplementation | str | None = Field( default=None, json_schema_extra={ - "description": "Specify a custom attention implementation, used mostly for kernels." + "description": "Attention backend: eager, flash, sdpa, xformers, flex, sage, s2, fp8, or a custom string for kernels." 
}, ) @@ -1347,6 +1352,81 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def normalize_attn_implementation(cls, data): + """Normalize attention config: map between attn_implementation enum and legacy boolean flags.""" + attn_impl = data.get("attn_implementation") + + # Mapping: attn_implementation value -> (primary flag, extra flags to set) + impl_to_flags = { + "eager": (("eager_attention",), ()), + "flash": (("flash_attention",), ()), + "sdpa": (("sdp_attention",), ()), + "xformers": (("xformers_attention",), ("flash_attention",)), + "flex": (("flex_attention",), ()), + "sage": (("sage_attention",), ("flash_attention",)), + "s2": (("s2_attention",), ("flash_attention",)), + "fp8": ((), ()), # new, no legacy flags + } + + # Reverse mapping: legacy flag -> attn_implementation value + flag_to_impl = { + "eager_attention": "eager", + "flash_attention": "flash", + "sdp_attention": "sdpa", + "xformers_attention": "xformers", + "flex_attention": "flex", + "sage_attention": "sage", + "s2_attention": "s2", + } + + # Find which legacy flags are set + set_flags = [f for f, impl in flag_to_impl.items() if data.get(f)] + + if attn_impl and set_flags: + # Both set — check consistency + if attn_impl in impl_to_flags: + expected_primary, expected_extra = impl_to_flags[attn_impl] + expected_flags = set(expected_primary) | set(expected_extra) + for flag in set_flags: + if flag not in expected_flags: + raise ValueError( + f"attn_implementation={attn_impl!r} conflicts with {flag}=true. " + f"Use only attn_implementation or the legacy flag, not both." + ) + elif attn_impl and not set_flags: + # attn_implementation set, no legacy flags — set them for backwards compat + if attn_impl in impl_to_flags: + primary, extra = impl_to_flags[attn_impl] + for flag in (*primary, *extra): + data[flag] = True + elif not attn_impl and set_flags: + # Legacy flags set, no attn_implementation — map to enum, warn + # Priority: specific backends first, then generic flash/sdp/eager + # s2 and sage require flash_attention internally, so they must be + # checked before flash_attention to avoid masking + priority = [ + "xformers_attention", + "s2_attention", + "sage_attention", + "flex_attention", + "flash_attention", + "sdp_attention", + "eager_attention", + ] + for flag in priority: + if flag in set_flags: + data["attn_implementation"] = flag_to_impl[flag] + LOG.warning( + "`%s: true` is deprecated. 
Use `attn_implementation: %s` instead.", + flag, + flag_to_impl[flag], + ) + break + + return data + @model_validator(mode="before") @classmethod def check_sageattn_wo_sample_packing(cls, data): diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index d4ff27ac9..12d59c974 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -97,6 +97,19 @@ class CustomSupportedOptimizers(str, Enum): flash_lion = "flash_lion" +class AttnImplementation(str, Enum): + """Attention backend implementations""" + + eager = "eager" # pylint: disable=invalid-name + flash = "flash" # pylint: disable=invalid-name + sdpa = "sdpa" # pylint: disable=invalid-name + xformers = "xformers" # pylint: disable=invalid-name + flex = "flex" # pylint: disable=invalid-name + sage = "sage" # pylint: disable=invalid-name + s2 = "s2" # pylint: disable=invalid-name + fp8 = "fp8" # pylint: disable=invalid-name + + class RingAttnFunc(str, Enum): """Enum class for supported `ring-flash-attn` implementations""" diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index fff69de26..b83396d4e 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -201,6 +201,7 @@ class AttentionValidationMixin: def check_sample_packing_without_attention(cls, data): if ( data.get("sample_packing") + and not data.get("attn_implementation") and not data.get("flash_attention") and not data.get("sdp_attention") and not data.get("flex_attention") @@ -215,7 +216,9 @@ class AttentionValidationMixin: @model_validator(mode="before") @classmethod def check_sample_packing_with_s2attn(cls, data): - if data.get("sample_packing") and data.get("s2_attention"): + if data.get("sample_packing") and ( + data.get("s2_attention") or data.get("attn_implementation") == "s2" + ): raise ValueError( "Received `sample_packing=true` and `s2_attention=true`; however, \ shifted-sparse attention does not currently support sample packing." @@ -225,10 +228,12 @@ class AttentionValidationMixin: @model_validator(mode="before") @classmethod def check_scaling_softmax_requires_flex(cls, data): - if data.get("scaling_softmax") and not data.get("flex_attention"): + if data.get("scaling_softmax") and not ( + data.get("flex_attention") or data.get("attn_implementation") == "flex" + ): raise ValueError( - "scaling_softmax requires flex_attention: true\n" - "Add 'flex_attention: true' to your config file.\n" + "scaling_softmax requires flex attention.\n" + "Add 'attn_implementation: flex' to your config file.\n" ) return data diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py new file mode 100644 index 000000000..71c052718 --- /dev/null +++ b/tests/test_attn_implementation.py @@ -0,0 +1,263 @@ +""" +Tests for attn_implementation normalization, registry registration, and +backwards compatibility with legacy boolean attention flags. 
+""" + +import pytest + +from axolotl.utils.schemas.config import AxolotlInputConfig + + +class TestAttnImplementationNormalizer: + """Test the normalize_attn_implementation validator.""" + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + # --- Forward mapping: attn_implementation -> legacy flags --- + + @pytest.mark.parametrize( + "impl,expected_flags", + [ + ("eager", {"eager_attention": True}), + ("flash", {"flash_attention": True}), + ("sdpa", {"sdp_attention": True}), + ("flex", {"flex_attention": True}), + ("xformers", {"xformers_attention": True, "flash_attention": True}), + ("sage", {"sage_attention": True, "flash_attention": True}), + ("s2", {"s2_attention": True, "flash_attention": True}), + ], + ) + def test_attn_impl_sets_legacy_flags(self, impl, expected_flags): + data = {"attn_implementation": impl} + result = AxolotlInputConfig.normalize_attn_implementation(data) + for flag, val in expected_flags.items(): + assert result.get(flag) == val, f"{impl}: expected {flag}={val}" + + def test_fp8_sets_no_legacy_flags(self): + result = self._normalize({"attn_implementation": "fp8"}) + for flag in [ + "flash_attention", + "sdp_attention", + "eager_attention", + "xformers_attention", + "sage_attention", + "flex_attention", + "s2_attention", + ]: + assert not result.get(flag), f"fp8 should not set {flag}" + + # --- Reverse mapping: legacy flags -> attn_implementation --- + + @pytest.mark.parametrize( + "flag,expected_impl", + [ + ("flash_attention", "flash"), + ("sdp_attention", "sdpa"), + ("xformers_attention", "xformers"), + ("flex_attention", "flex"), + ("sage_attention", "sage"), + ("eager_attention", "eager"), + ("s2_attention", "s2"), + ], + ) + def test_legacy_flag_sets_attn_impl(self, flag, expected_impl): + result = self._normalize({flag: True}) + assert result["attn_implementation"] == expected_impl + + # --- Priority: s2/sage should win over flash when both set --- + + def test_s2_plus_flash_maps_to_s2(self): + """Legacy configs often have both s2_attention and flash_attention.""" + result = self._normalize({"s2_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "s2" + + def test_sage_plus_flash_maps_to_sage(self): + """sage_attention should take priority over flash_attention.""" + result = self._normalize({"sage_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "sage" + + # --- Consistency: both set, matching --- + + def test_consistent_both_set_no_error(self): + result = self._normalize( + {"attn_implementation": "flash", "flash_attention": True} + ) + assert result["attn_implementation"] == "flash" + assert result["flash_attention"] is True + + def test_consistent_xformers_with_extra_flags(self): + """xformers needs flash_attention=True, so both flags with attn_impl should be OK.""" + result = self._normalize( + { + "attn_implementation": "xformers", + "xformers_attention": True, + "flash_attention": True, + } + ) + assert result["attn_implementation"] == "xformers" + + def test_consistent_s2_with_flash(self): + result = self._normalize( + { + "attn_implementation": "s2", + "s2_attention": True, + "flash_attention": True, + } + ) + assert result["attn_implementation"] == "s2" + + # --- Conflict detection --- + + def test_conflicting_impl_and_flag_raises(self): + with pytest.raises(ValueError, match="conflicts with"): + self._normalize({"attn_implementation": "flash", "sdp_attention": True}) + + def 
test_conflicting_xformers_impl_with_sdp_flag(self): + with pytest.raises(ValueError, match="conflicts with"): + self._normalize({"attn_implementation": "xformers", "sdp_attention": True}) + + # --- Hub kernel strings pass through --- + + def test_hub_kernel_passthrough(self): + result = self._normalize( + {"attn_implementation": "kernels-community/flash-attn3"} + ) + assert result["attn_implementation"] == "kernels-community/flash-attn3" + # Should not set any legacy flags + for flag in [ + "flash_attention", + "sdp_attention", + "eager_attention", + "xformers_attention", + ]: + assert not result.get(flag) + + def test_custom_string_passthrough(self): + result = self._normalize({"attn_implementation": "my_custom_kernel"}) + assert result["attn_implementation"] == "my_custom_kernel" + + # --- No attention set --- + + def test_no_attention_set_is_noop(self): + result = self._normalize({"some_other_config": True}) + assert result.get("attn_implementation") is None + + # --- Sample packing interactions --- + + def test_xformers_with_sample_packing_sets_flash(self): + """xformers + sample_packing needs flash_attention=True for the patch chain.""" + result = self._normalize( + {"attn_implementation": "xformers", "sample_packing": True} + ) + assert result["xformers_attention"] is True + assert result["flash_attention"] is True + + +class TestAttnImplToHFMapping: + """Test that attn_implementation enum values map correctly to HF strings.""" + + # This dict mirrors _ATTN_IMPL_TO_HF in model.py + _ATTN_IMPL_TO_HF = { + "eager": "eager", + "flash": "flash_attention_2", + "sdpa": "sdpa", + "xformers": "xformers", + "flex": "flex_attention", + "sage": "sage", + "s2": "flash_attention_2", + "fp8": "sdpa", + } + + @pytest.mark.parametrize( + "impl,expected_hf", + [ + ("eager", "eager"), + ("flash", "flash_attention_2"), + ("sdpa", "sdpa"), + ("xformers", "xformers"), + ("flex", "flex_attention"), + ("sage", "sage"), + ("s2", "flash_attention_2"), + ("fp8", "sdpa"), + ], + ) + def test_known_impl_maps_correctly(self, impl, expected_hf): + assert self._ATTN_IMPL_TO_HF[impl] == expected_hf + + def test_hub_kernel_falls_through(self): + """Hub kernel strings should pass through .get() unchanged.""" + hub_str = "kernels-community/flash-attn3" + result = self._ATTN_IMPL_TO_HF.get(hub_str, hub_str) + assert result == hub_str + + +def _xformers_available(): + try: + import xformers.ops # noqa: F401 + + return True + except (ImportError, OSError): + return False + + +class TestAttentionRegistration: + """Test that attention backends register correctly in HF's registries.""" + + @pytest.mark.skipif(not _xformers_available(), reason="xformers not available") + def test_register_xformers(self): + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from axolotl.monkeypatch.attention import register_xformers_attn + + register_xformers_attn() + + assert "xformers" in ALL_ATTENTION_FUNCTIONS + assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS + # xformers mask should be the same function as flash_attention_2's mask + assert ( + ALL_MASK_ATTENTION_FUNCTIONS["xformers"] + == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) + + def test_register_sage(self): + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from axolotl.monkeypatch.attention import register_sage_attn + + register_sage_attn() + + assert "sage" in ALL_ATTENTION_FUNCTIONS + assert 
"sage" in ALL_MASK_ATTENTION_FUNCTIONS + assert ( + ALL_MASK_ATTENTION_FUNCTIONS["sage"] + == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) + + @pytest.mark.skipif(not _xformers_available(), reason="xformers not available") + def test_xformers_does_not_overwrite_fa2(self): + """Registering xformers should not modify the flash_attention_2 slot.""" + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] + + from axolotl.monkeypatch.attention import register_xformers_attn + + register_xformers_attn() + + assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2 + + def test_sage_does_not_overwrite_fa2(self): + """Registering sage should not modify the flash_attention_2 slot.""" + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] + + from axolotl.monkeypatch.attention import register_sage_attn + + register_sage_attn() + + assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2