diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py index 50e0f651a..4c6495a2a 100644 --- a/tests/test_attn_implementation.py +++ b/tests/test_attn_implementation.py @@ -1,11 +1,11 @@ -"""Tests for attn_implementation normalization and capability computation. +"""Tests for attn_implementation: input normalization, canonical-value +acceptance, capability flags, backend registration, and downstream validators. -Covers the contract: -- `attn_implementation` accepts canonical names only; short-form aliases are rejected. -- Legacy boolean flags are mapped to the canonical value, warned on, and stripped. -- Canonical `attn_implementation` + legacy flag raises. -- Capability flags are computed from `attn_implementation`. -- Validator behaviour for sample_packing / scaling_softmax / s2 / fp8 / hub kernels. +Test classes are organized by feature concern, not by the layer of the schema +where the behavior is implemented (classmethod normalizer vs. field validator +vs. full `validate_config` pipeline). Each class covers a single contract end +to end, dropping into the lower layer only where it gives faster or sharper +coverage of an isolated branch. """ import logging @@ -43,150 +43,22 @@ def _capture_axolotl_warnings(caplog): ax_logger.propagate = old_propagate -class TestNormalizerLegacyMapping: - """Legacy boolean flags map to canonical attn_implementation.""" +def _xformers_available(): + try: + import xformers.ops # noqa: F401 - @staticmethod - def _normalize(data): - return AxolotlInputConfig.normalize_attn_implementation(data) - - @pytest.mark.parametrize( - "flag,expected", - [ - ("flash_attention", "flash_attention_2"), - ("sdp_attention", "sdpa"), - ("xformers_attention", "xformers"), - ("flex_attention", "flex_attention"), - ("sage_attention", "sage"), - ("eager_attention", "eager"), - ("s2_attention", "s2"), - ], - ) - def test_legacy_flag_maps_to_canonical(self, flag, expected): - result = self._normalize({flag: True}) - assert result["attn_implementation"] == expected - - def test_legacy_flags_are_stripped_after_mapping(self): - result = self._normalize({"flash_attention": True}) - for flag in [ - "flash_attention", - "sdp_attention", - "xformers_attention", - "flex_attention", - "sage_attention", - "eager_attention", - "s2_attention", - ]: - assert flag not in result - - def test_s2_plus_flash_priority_is_s2(self): - result = self._normalize({"s2_attention": True, "flash_attention": True}) - assert result["attn_implementation"] == "s2" - - def test_sage_plus_flash_priority_is_sage(self): - result = self._normalize({"sage_attention": True, "flash_attention": True}) - assert result["attn_implementation"] == "sage" - - -class TestNormalizerConflicts: - """Canonical attn_implementation + legacy flag raises.""" - - @staticmethod - def _normalize(data): - return AxolotlInputConfig.normalize_attn_implementation(data) - - def test_canonical_plus_legacy_flag_raises(self): - with pytest.raises(ValueError, match="cannot be combined with legacy"): - self._normalize( - {"attn_implementation": "flash_attention_2", "flash_attention": True} - ) - - def test_canonical_plus_unrelated_legacy_flag_raises(self): - with pytest.raises(ValueError, match="cannot be combined with legacy"): - self._normalize( - {"attn_implementation": "xformers", "flash_attention": True} - ) - - -class TestNormalizerPassthrough: - """Canonical values and hub-kernel paths pass through.""" - - def test_canonical_no_legacy_is_noop(self): - data = {"attn_implementation": "flash_attention_2"} - result = AxolotlInputConfig.normalize_attn_implementation(data) - assert result["attn_implementation"] == "flash_attention_2" - - def test_hub_kernel_passes_through(self): - data = {"attn_implementation": "kernels-community/flash-attn3"} - result = AxolotlInputConfig.normalize_attn_implementation(data) - assert result["attn_implementation"] == "kernels-community/flash-attn3" - - def test_no_attention_set_is_noop(self): - result = AxolotlInputConfig.normalize_attn_implementation( - {"some_other_config": True} - ) - assert result.get("attn_implementation") is None - - -class TestGemma4Hybrid: - """gemma4_hybrid_attn_impl defaults to flash_attention_2.""" - - def test_gemma4_hybrid_defaults_to_fa2(self): - result = AxolotlInputConfig.normalize_attn_implementation( - {"gemma4_hybrid_attn_impl": True} - ) - assert result["attn_implementation"] == "flash_attention_2" - - def test_gemma4_hybrid_with_incompatible_impl_raises(self): - """Setting gemma4_hybrid alongside a non-FA2 attn_implementation is a - configuration error — the hybrid path requires FA2 under the hood.""" - with pytest.raises( - ValueError, match="requires attn_implementation=flash_attention_2" - ): - AxolotlInputConfig.normalize_attn_implementation( - {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"} - ) - - def test_gemma4_hybrid_with_explicit_fa2_passes(self): - result = AxolotlInputConfig.normalize_attn_implementation( - { - "gemma4_hybrid_attn_impl": True, - "attn_implementation": "flash_attention_2", - } - ) - assert result["attn_implementation"] == "flash_attention_2" - - -class TestFieldValidator: - """attn_implementation field_validator rejects short-form aliases.""" - - def test_canonical_accepted(self): - for impl in CANONICAL_ATTN_IMPLS: - assert AxolotlInputConfig.validate_attn_implementation(impl) == impl - - def test_hub_kernel_accepted(self): - for impl in ( - "kernels-community/flash-attn3", - "kernels-community/sage-attention", - "someorg/custom-kernel", - ): - assert AxolotlInputConfig.validate_attn_implementation(impl) == impl - - def test_none_accepted(self): - assert AxolotlInputConfig.validate_attn_implementation(None) is None - - @pytest.mark.parametrize("alias", ["flash", "flex", "sdp"]) - def test_short_form_alias_rejected(self, alias): - with pytest.raises(ValueError, match="is not accepted"): - AxolotlInputConfig.validate_attn_implementation(alias) - - def test_unknown_without_slash_rejected(self): - with pytest.raises(ValueError, match="not a recognized backend"): - AxolotlInputConfig.validate_attn_implementation("not_a_real_backend") + return True + except (ImportError, OSError): + return False class TestCapabilityTables: - """Capability tables are keyed by canonical names and cover the expected backends.""" + """Backend capability classification. + + Asserts both the static frozensets in `enums.py` and the `computed_field` + properties on a validated config read consistently from those tables, and + that user YAML cannot override the computed flags. + """ @pytest.mark.parametrize( "impl", @@ -239,17 +111,25 @@ class TestCapabilityTables: assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING + def test_computed_flags_readable_on_validated_cfg(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="sdpa") + validated = validate_config(cfg) + assert validated.attn_implementation == "sdpa" + assert validated.attn_supports_packing is False + assert validated.attn_uses_flash_lib is False + assert validated.attn_needs_dtype_cast is False -def _xformers_available(): - try: - import xformers.ops # noqa: F401 - - return True - except (ImportError, OSError): - return False + def test_computed_flags_not_overridable_from_yaml(self, min_base_cfg): + """YAML attempts to override a computed field must not win.""" + cfg = min_base_cfg | DictDefault( + attn_implementation="eager", attn_uses_flash_lib=True + ) + validated = validate_config(cfg) + # The computed field reflects the backend, not the YAML input. + assert validated.attn_uses_flash_lib is False -class TestAttentionRegistration: +class TestBackendRegistration: """Axolotl-owned backends register under their canonical names in HF's registries.""" @pytest.mark.skipif(not _xformers_available(), reason="xformers not available") @@ -307,16 +187,65 @@ class TestAttentionRegistration: assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2 -class TestValidatedConfig: - """Exercise the full validator chain on `AxolotlInputConfig(**data)`. - - Classmethod tests above cover the normalizer in isolation. These tests - verify that `model_validator(mode="before")` ordering works under the real - MRO chain — specifically that legacy flags are stripped, the computed - capability fields are readable on the validated instance, and - `attn_supports_packing`/`attn_uses_flash_lib` aren't overridable from YAML. +class TestLegacyFlagDeprecation: + """Legacy boolean flags (flash_attention, sdp_attention, ...) map to a + canonical attn_implementation value, are stripped from the validated + config, and cannot be combined with an explicit canonical value. """ + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + @pytest.mark.parametrize( + "flag,expected", + [ + ("flash_attention", "flash_attention_2"), + ("sdp_attention", "sdpa"), + ("xformers_attention", "xformers"), + ("flex_attention", "flex_attention"), + ("sage_attention", "sage"), + ("eager_attention", "eager"), + ("s2_attention", "s2"), + ], + ) + def test_legacy_flag_maps_to_canonical(self, flag, expected): + result = self._normalize({flag: True}) + assert result["attn_implementation"] == expected + + def test_legacy_flags_are_stripped_after_mapping(self): + result = self._normalize({"flash_attention": True}) + for flag in [ + "flash_attention", + "sdp_attention", + "xformers_attention", + "flex_attention", + "sage_attention", + "eager_attention", + "s2_attention", + ]: + assert flag not in result + + def test_s2_plus_flash_priority_is_s2(self): + result = self._normalize({"s2_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "s2" + + def test_sage_plus_flash_priority_is_sage(self): + result = self._normalize({"sage_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "sage" + + def test_canonical_plus_legacy_flag_raises(self): + with pytest.raises(ValueError, match="cannot be combined with legacy"): + self._normalize( + {"attn_implementation": "flash_attention_2", "flash_attention": True} + ) + + def test_canonical_plus_unrelated_legacy_flag_raises(self): + with pytest.raises(ValueError, match="cannot be combined with legacy"): + self._normalize( + {"attn_implementation": "xformers", "flash_attention": True} + ) + def test_legacy_flag_stripped_on_validated_cfg(self, min_base_cfg): cfg = min_base_cfg | DictDefault(flash_attention=True) validated = validate_config(cfg) @@ -325,35 +254,6 @@ class TestValidatedConfig: # (normalizer pops it, model_dump excludes Nones). assert "flash_attention" not in dict(validated) - def test_canonical_name_passes_through(self, min_base_cfg): - cfg = min_base_cfg | DictDefault(attn_implementation="flash_attention_3") - validated = validate_config(cfg) - assert validated.attn_implementation == "flash_attention_3" - assert validated.attn_uses_flash_lib is True - assert validated.attn_supports_packing is True - - def test_computed_capability_flags_readable(self, min_base_cfg): - cfg = min_base_cfg | DictDefault(attn_implementation="sdpa") - validated = validate_config(cfg) - assert validated.attn_implementation == "sdpa" - assert validated.attn_supports_packing is False - assert validated.attn_uses_flash_lib is False - assert validated.attn_needs_dtype_cast is False - - def test_capability_flags_not_overridable_from_yaml(self, min_base_cfg): - """YAML attempts to override a computed field must not win.""" - cfg = min_base_cfg | DictDefault( - attn_implementation="eager", attn_uses_flash_lib=True - ) - validated = validate_config(cfg) - # The computed field reflects the backend, not the YAML input. - assert validated.attn_uses_flash_lib is False - - def test_short_form_alias_rejected_on_full_validation(self, min_base_cfg): - cfg = min_base_cfg | DictDefault(attn_implementation="flash") - with pytest.raises(ValueError, match="is not accepted"): - validate_config(cfg) - def test_canonical_plus_legacy_rejected_on_full_validation(self, min_base_cfg): cfg = min_base_cfg | DictDefault( attn_implementation="flash_attention_2", flash_attention=True @@ -362,13 +262,68 @@ class TestValidatedConfig: validate_config(cfg) def test_s2_plus_flash_maps_to_s2_on_full_validation(self, min_base_cfg): - """The inherited `check_attention_fields` mixin used to raise here; - after Phase 1 it's removed and the normalizer owns the priority.""" + """Priority resolution applies through the full validator chain too.""" cfg = min_base_cfg | DictDefault(s2_attention=True, flash_attention=True) validated = validate_config(cfg) assert validated.attn_implementation == "s2" - def test_hub_kernel_on_full_validation(self, min_base_cfg): + +class TestCanonicalValueAcceptance: + """`attn_implementation` accepts canonical names and `org/name` hub-kernel + paths. Short-form aliases (`flash`, `flex`, `sdp`) and unknown bare names + are rejected. Absent input is a noop. + """ + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + def test_canonical_value_is_passthrough(self): + data = {"attn_implementation": "flash_attention_2"} + result = self._normalize(data) + assert result["attn_implementation"] == "flash_attention_2" + + def test_hub_kernel_is_passthrough(self): + data = {"attn_implementation": "kernels-community/flash-attn3"} + result = self._normalize(data) + assert result["attn_implementation"] == "kernels-community/flash-attn3" + + def test_no_attention_set_is_noop(self): + result = self._normalize({"some_other_config": True}) + assert result.get("attn_implementation") is None + + def test_field_validator_accepts_all_canonical(self): + for impl in CANONICAL_ATTN_IMPLS: + assert AxolotlInputConfig.validate_attn_implementation(impl) == impl + + def test_field_validator_accepts_hub_kernels(self): + for impl in ( + "kernels-community/flash-attn3", + "kernels-community/sage-attention", + "someorg/custom-kernel", + ): + assert AxolotlInputConfig.validate_attn_implementation(impl) == impl + + def test_field_validator_accepts_none(self): + assert AxolotlInputConfig.validate_attn_implementation(None) is None + + @pytest.mark.parametrize("alias", ["flash", "flex", "sdp"]) + def test_short_form_alias_rejected(self, alias): + with pytest.raises(ValueError, match="is not accepted"): + AxolotlInputConfig.validate_attn_implementation(alias) + + def test_unknown_bare_name_rejected(self): + with pytest.raises(ValueError, match="not a recognized backend"): + AxolotlInputConfig.validate_attn_implementation("not_a_real_backend") + + def test_canonical_value_passes_through_full_validation(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="flash_attention_3") + validated = validate_config(cfg) + assert validated.attn_implementation == "flash_attention_3" + assert validated.attn_uses_flash_lib is True + assert validated.attn_supports_packing is True + + def test_hub_kernel_passes_through_full_validation(self, min_base_cfg): cfg = min_base_cfg | DictDefault( attn_implementation="kernels-community/flash-attn3" ) @@ -377,11 +332,51 @@ class TestValidatedConfig: assert validated.attn_uses_flash_lib is True assert validated.attn_supports_packing is True + def test_short_form_alias_rejected_on_full_validation(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="flash") + with pytest.raises(ValueError, match="is not accepted"): + validate_config(cfg) -class TestAttentionValidators: - """Regression tests for sample_packing / scaling_softmax / s2 / fp8 validators.""" - def test_sample_packing_with_eager_warns(self, min_base_cfg, caplog): +class TestGemma4HybridMode: + """`gemma4_hybrid_attn_impl` pins `attn_implementation` to `flash_attention_2`.""" + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + def test_defaults_to_flash_attention_2(self): + result = self._normalize({"gemma4_hybrid_attn_impl": True}) + assert result["attn_implementation"] == "flash_attention_2" + + def test_explicit_fa2_passes(self): + result = self._normalize( + { + "gemma4_hybrid_attn_impl": True, + "attn_implementation": "flash_attention_2", + } + ) + assert result["attn_implementation"] == "flash_attention_2" + + def test_non_fa2_raises(self): + """The hybrid path requires FA2 under the hood — any other backend is + a configuration error.""" + with pytest.raises( + ValueError, match="requires attn_implementation=flash_attention_2" + ): + self._normalize( + {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"} + ) + + +class TestSamplePackingValidation: + """`sample_packing` requires a varlen-capable backend. + + Non-varlen backends (eager, sdpa) warn about cross-sample contamination; + s2 raises outright because shifted-sparse attention has no varlen path. + """ + + def test_eager_warns(self, min_base_cfg, caplog): cfg = min_base_cfg | DictDefault( attn_implementation="eager", sample_packing=True ) @@ -392,7 +387,7 @@ class TestAttentionValidators: for r in caplog.records ) - def test_sample_packing_with_sdpa_warns(self, min_base_cfg, caplog): + def test_sdpa_warns(self, min_base_cfg, caplog): cfg = min_base_cfg | DictDefault( attn_implementation="sdpa", sample_packing=True ) @@ -403,7 +398,7 @@ class TestAttentionValidators: for r in caplog.records ) - def test_sample_packing_with_flash_does_not_warn(self, min_base_cfg, caplog): + def test_flash_attention_2_does_not_warn(self, min_base_cfg, caplog): cfg = min_base_cfg | DictDefault( attn_implementation="flash_attention_2", sample_packing=True ) @@ -414,21 +409,25 @@ class TestAttentionValidators: for r in caplog.records ) - def test_sample_packing_with_s2_raises(self, min_base_cfg): + def test_s2_raises(self, min_base_cfg): cfg = min_base_cfg | DictDefault(attn_implementation="s2", sample_packing=True) with pytest.raises( ValueError, match="shifted-sparse attention does not currently support" ): validate_config(cfg) - def test_scaling_softmax_without_flex_raises(self, min_base_cfg): + +class TestScalingSoftmaxValidation: + """`scaling_softmax` is only implemented under flex_attention.""" + + def test_non_flex_raises(self, min_base_cfg): cfg = min_base_cfg | DictDefault( attn_implementation="flash_attention_2", scaling_softmax=True ) with pytest.raises(ValueError, match="scaling_softmax requires flex"): validate_config(cfg) - def test_scaling_softmax_with_flex_passes(self, min_base_cfg): + def test_flex_passes(self, min_base_cfg): cfg = min_base_cfg | DictDefault( attn_implementation="flex_attention", scaling_softmax=True )