"""Tests for attn_implementation: normalization, canonical-value acceptance,
capability flags, backend registration, and downstream validators.
"""

import logging
from contextlib import contextmanager

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.schemas.enums import (
    ATTN_IMPLS_SUPPORTING_PACKING,
    ATTN_IMPLS_USING_FLASH_LIB,
    ATTN_IMPLS_WITHOUT_DTYPE_CAST,
    CANONICAL_ATTN_IMPLS,
)


@contextmanager
def _capture_axolotl_warnings(caplog):
    """Capture WARNINGs from `axolotl.*` loggers via caplog.

    `axolotl.cli` calls `configure_logging()` at import time, which sets
    `propagate=False` on the `axolotl` logger so records do not reach the
    root logger that pytest's `caplog` hooks. This helper temporarily
    re-enables propagation for the duration of the block.
    """
    ax_logger = logging.getLogger("axolotl")
    old_propagate = ax_logger.propagate
    ax_logger.propagate = True
    try:
        with caplog.at_level(logging.WARNING, logger="axolotl"):
            yield
    finally:
        # Always restore the original propagation setting, even if the
        # wrapped block raises.
        ax_logger.propagate = old_propagate


def _xformers_available():
    """Return True if `xformers.ops` can be imported in this environment."""
    try:
        import xformers.ops  # noqa: F401

        return True
    # OSError is caught in addition to ImportError: presumably the wheel can
    # be installed while its native extension fails to load — TODO confirm.
    except (ImportError, OSError):
        return False


class TestCapabilityTables:
    """Backend capability classification via frozensets and computed_field properties."""

    @pytest.mark.parametrize(
        "impl",
        [
            "flash_attention_2",
            "flash_attention_3",
            "flex_attention",
            "xformers",
            "sage",
        ],
    )
    def test_supports_packing(self, impl):
        assert impl in ATTN_IMPLS_SUPPORTING_PACKING

    @pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
    def test_does_not_support_packing(self, impl):
        assert impl not in ATTN_IMPLS_SUPPORTING_PACKING

    @pytest.mark.parametrize("impl", ["flash_attention_2", "flash_attention_3", "s2"])
    def test_uses_flash_lib(self, impl):
        assert impl in ATTN_IMPLS_USING_FLASH_LIB

    @pytest.mark.parametrize(
        "impl", ["eager", "sdpa", "xformers", "flex_attention", "sage", "fp8"]
    )
    def test_does_not_use_flash_lib(self, impl):
        assert impl not in ATTN_IMPLS_USING_FLASH_LIB

    @pytest.mark.parametrize("impl", ["eager", "sdpa"])
    def test_no_dtype_cast(self, impl):
        assert impl in ATTN_IMPLS_WITHOUT_DTYPE_CAST

    @pytest.mark.parametrize(
        "impl",
        [
            "flash_attention_2",
            "flash_attention_3",
            "flex_attention",
            "xformers",
            "sage",
            "s2",
            "fp8",
        ],
    )
    def test_needs_dtype_cast(self, impl):
        assert impl not in ATTN_IMPLS_WITHOUT_DTYPE_CAST

    def test_known_hub_kernels_classified(self):
        # Hub-kernel paths (org/name) are classified in the capability
        # frozensets alongside the bare canonical names.
        assert "kernels-community/flash-attn3" in ATTN_IMPLS_SUPPORTING_PACKING
        assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB
        assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING

    def test_computed_flags_readable_on_validated_cfg(self, min_base_cfg):
        # min_base_cfg is an external fixture (presumably from conftest);
        # it is assumed to be a DictDefault-compatible minimal config.
        cfg = min_base_cfg | DictDefault(attn_implementation="sdpa")
        validated = validate_config(cfg)
        assert validated.attn_implementation == "sdpa"
        assert validated.attn_supports_packing is False
        assert validated.attn_uses_flash_lib is False
        assert validated.attn_needs_dtype_cast is False

    def test_computed_flags_not_overridable_from_yaml(self, min_base_cfg):
        """YAML attempts to override a computed field must not win."""
        cfg = min_base_cfg | DictDefault(
            attn_implementation="eager", attn_uses_flash_lib=True
        )
        validated = validate_config(cfg)
        # The computed field reflects the backend, not the YAML input.
        assert validated.attn_uses_flash_lib is False


class TestBackendRegistration:
    """Axolotl-owned backends register under their canonical names in HF's registries."""

    @pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
    def test_register_xformers(self):
        from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        from axolotl.monkeypatch.attention import register_xformers_attn

        register_xformers_attn()
        assert "xformers" in ALL_ATTENTION_FUNCTIONS
        assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS
        # The mask function registered for xformers is the same object as
        # the one registered for flash_attention_2.
        assert (
            ALL_MASK_ATTENTION_FUNCTIONS["xformers"]
            == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
        )

    def test_register_sage(self):
        from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        from axolotl.monkeypatch.attention import register_sage_attn

        register_sage_attn()
        assert "sage" in ALL_ATTENTION_FUNCTIONS
        assert "sage" in ALL_MASK_ATTENTION_FUNCTIONS
        assert (
            ALL_MASK_ATTENTION_FUNCTIONS["sage"]
            == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
        )

    @pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
    def test_xformers_does_not_overwrite_fa2(self):
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        # Registering xformers must leave the existing FA2 entry untouched
        # (identity check, not just equality).
        original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
        from axolotl.monkeypatch.attention import register_xformers_attn

        register_xformers_attn()
        assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2

    def test_sage_does_not_overwrite_fa2(self):
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
        from axolotl.monkeypatch.attention import register_sage_attn

        register_sage_attn()
        assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2


class TestLegacyFlagDeprecation:
    """Legacy boolean flags (flash_attention, sdp_attention, ...) map to a
    canonical attn_implementation value, are stripped from the validated
    config, and cannot be combined with an explicit canonical value.
    """

    @staticmethod
    def _normalize(data):
        # Thin wrapper: calls the model validator directly on a raw dict.
        return AxolotlInputConfig.normalize_attn_implementation(data)

    @pytest.mark.parametrize(
        "flag,expected",
        [
            ("flash_attention", "flash_attention_2"),
            ("sdp_attention", "sdpa"),
            ("xformers_attention", "xformers"),
            ("flex_attention", "flex_attention"),
            ("sage_attention", "sage"),
            ("eager_attention", "eager"),
            ("s2_attention", "s2"),
        ],
    )
    def test_legacy_flag_maps_to_canonical(self, flag, expected):
        result = self._normalize({flag: True})
        assert result["attn_implementation"] == expected

    def test_legacy_flags_are_stripped_after_mapping(self):
        result = self._normalize({"flash_attention": True})
        for flag in [
            "flash_attention",
            "sdp_attention",
            "xformers_attention",
            "flex_attention",
            "sage_attention",
            "eager_attention",
            "s2_attention",
        ]:
            assert flag not in result

    def test_s2_plus_flash_priority_is_s2(self):
        # When two legacy flags are set, s2 wins over flash_attention.
        result = self._normalize({"s2_attention": True, "flash_attention": True})
        assert result["attn_implementation"] == "s2"

    def test_sage_plus_flash_priority_is_sage(self):
        result = self._normalize({"sage_attention": True, "flash_attention": True})
        assert result["attn_implementation"] == "sage"

    def test_canonical_plus_legacy_flag_raises(self):
        with pytest.raises(ValueError, match="cannot be combined with legacy"):
            self._normalize(
                {"attn_implementation": "flash_attention_2", "flash_attention": True}
            )

    def test_canonical_plus_unrelated_legacy_flag_raises(self):
        # Rejected even when the legacy flag does not conflict with the
        # explicit canonical value's backend.
        with pytest.raises(ValueError, match="cannot be combined with legacy"):
            self._normalize(
                {"attn_implementation": "xformers", "flash_attention": True}
            )

    def test_legacy_flag_stripped_on_validated_cfg(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(flash_attention=True)
        validated = validate_config(cfg)
        assert validated.attn_implementation == "flash_attention_2"
        # Legacy flag must not survive to the validated DictDefault
        # (normalizer pops it, model_dump excludes Nones).
        assert "flash_attention" not in dict(validated)

    def test_canonical_plus_legacy_rejected_on_full_validation(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="flash_attention_2", flash_attention=True
        )
        with pytest.raises(ValueError, match="cannot be combined with legacy"):
            validate_config(cfg)

    def test_s2_plus_flash_maps_to_s2_on_full_validation(self, min_base_cfg):
        """Priority resolution applies through the full validator chain too."""
        cfg = min_base_cfg | DictDefault(s2_attention=True, flash_attention=True)
        validated = validate_config(cfg)
        assert validated.attn_implementation == "s2"


class TestCanonicalValueAcceptance:
    """`attn_implementation` accepts canonical names and `org/name` hub-kernel paths.

    Short-form aliases (`flash`, `flex`, `sdp`) and unknown bare names are
    rejected. Absent input is a noop.
    """

    @staticmethod
    def _normalize(data):
        return AxolotlInputConfig.normalize_attn_implementation(data)

    def test_canonical_value_is_passthrough(self):
        data = {"attn_implementation": "flash_attention_2"}
        result = self._normalize(data)
        assert result["attn_implementation"] == "flash_attention_2"

    def test_hub_kernel_is_passthrough(self):
        data = {"attn_implementation": "kernels-community/flash-attn3"}
        result = self._normalize(data)
        assert result["attn_implementation"] == "kernels-community/flash-attn3"

    def test_no_attention_set_is_noop(self):
        # Normalizer does not inject a default when no attention key is set.
        result = self._normalize({"some_other_config": True})
        assert result.get("attn_implementation") is None

    def test_field_validator_accepts_all_canonical(self):
        for impl in CANONICAL_ATTN_IMPLS:
            assert AxolotlInputConfig.validate_attn_implementation(impl) == impl

    def test_field_validator_accepts_hub_kernels(self):
        # Any org/name path is accepted, not only the known hub kernels.
        for impl in (
            "kernels-community/flash-attn3",
            "kernels-community/sage-attention",
            "someorg/custom-kernel",
        ):
            assert AxolotlInputConfig.validate_attn_implementation(impl) == impl

    def test_field_validator_accepts_none(self):
        assert AxolotlInputConfig.validate_attn_implementation(None) is None

    @pytest.mark.parametrize("alias", ["flash", "flex", "sdp"])
    def test_short_form_alias_rejected(self, alias):
        with pytest.raises(ValueError, match="is not accepted"):
            AxolotlInputConfig.validate_attn_implementation(alias)

    def test_unknown_bare_name_rejected(self):
        with pytest.raises(ValueError, match="not a recognized backend"):
            AxolotlInputConfig.validate_attn_implementation("not_a_real_backend")

    def test_canonical_value_passes_through_full_validation(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(attn_implementation="flash_attention_3")
        validated = validate_config(cfg)
        assert validated.attn_implementation == "flash_attention_3"
        assert validated.attn_uses_flash_lib is True
        assert validated.attn_supports_packing is True

    def test_hub_kernel_passes_through_full_validation(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="kernels-community/flash-attn3"
        )
        validated = validate_config(cfg)
        assert validated.attn_implementation == "kernels-community/flash-attn3"
        assert validated.attn_uses_flash_lib is True
        assert validated.attn_supports_packing is True

    def test_short_form_alias_rejected_on_full_validation(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(attn_implementation="flash")
        with pytest.raises(ValueError, match="is not accepted"):
            validate_config(cfg)


class TestGemma4HybridMode:
    """`gemma4_hybrid_attn_impl` pins `attn_implementation` to `flash_attention_2`."""

    @staticmethod
    def _normalize(data):
        return AxolotlInputConfig.normalize_attn_implementation(data)

    def test_defaults_to_flash_attention_2(self):
        # With the hybrid flag alone, the normalizer fills in FA2.
        result = self._normalize({"gemma4_hybrid_attn_impl": True})
        assert result["attn_implementation"] == "flash_attention_2"

    def test_explicit_fa2_passes(self):
        result = self._normalize(
            {
                "gemma4_hybrid_attn_impl": True,
                "attn_implementation": "flash_attention_2",
            }
        )
        assert result["attn_implementation"] == "flash_attention_2"

    def test_non_fa2_raises(self):
        with pytest.raises(
            ValueError, match="requires attn_implementation=flash_attention_2"
        ):
            self._normalize(
                {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
            )


class TestSamplePackingValidation:
    """`sample_packing` warns for non-varlen backends; s2 raises outright."""

    def test_eager_warns(self, min_base_cfg, caplog):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="eager", sample_packing=True
        )
        with _capture_axolotl_warnings(caplog):
            validate_config(cfg)
        assert any(
            "does not handle cross-sample decontamination" in r.getMessage()
            for r in caplog.records
        )

    def test_sdpa_warns(self, min_base_cfg, caplog):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="sdpa", sample_packing=True
        )
        with _capture_axolotl_warnings(caplog):
            validate_config(cfg)
        assert any(
            "does not handle cross-sample decontamination" in r.getMessage()
            for r in caplog.records
        )

    def test_flash_attention_2_does_not_warn(self, min_base_cfg, caplog):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="flash_attention_2", sample_packing=True
        )
        with _capture_axolotl_warnings(caplog):
            validate_config(cfg)
        assert not any(
            "does not handle cross-sample decontamination" in r.getMessage()
            for r in caplog.records
        )

    def test_s2_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(attn_implementation="s2", sample_packing=True)
        with pytest.raises(
            ValueError, match="shifted-sparse attention does not currently support"
        ):
            validate_config(cfg)


class TestScalingSoftmaxValidation:
    """`scaling_softmax` is only implemented under flex_attention."""

    def test_non_flex_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="flash_attention_2", scaling_softmax=True
        )
        with pytest.raises(ValueError, match="scaling_softmax requires flex"):
            validate_config(cfg)

    def test_flex_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            attn_implementation="flex_attention", scaling_softmax=True
        )
        validated = validate_config(cfg)
        assert validated.attn_implementation == "flex_attention"