Files
axolotl/tests/test_attn_implementation.py

436 lines
16 KiB
Python

"""Tests for attn_implementation: input normalization, canonical-value
acceptance, capability flags, backend registration, and downstream validators.
Test classes are organized by feature concern, not by the layer of the schema
where the behavior is implemented (classmethod normalizer vs. field validator
vs. full `validate_config` pipeline). Each class covers a single contract end
to end, dropping into the lower layer only where it gives faster or sharper
coverage of an isolated branch.
"""
import logging
from contextlib import contextmanager
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.schemas.enums import (
ATTN_IMPLS_SUPPORTING_PACKING,
ATTN_IMPLS_USING_FLASH_LIB,
ATTN_IMPLS_WITHOUT_DTYPE_CAST,
CANONICAL_ATTN_IMPLS,
)
@contextmanager
def _capture_axolotl_warnings(caplog):
"""Capture WARNINGs from `axolotl.*` loggers via caplog.
`axolotl.cli` calls `configure_logging()` at import time, which sets
`propagate=False` on the `axolotl` logger so records do not reach the root
logger that pytest's `caplog` hooks. This helper temporarily re-enables
propagation for the duration of the block.
"""
ax_logger = logging.getLogger("axolotl")
old_propagate = ax_logger.propagate
ax_logger.propagate = True
try:
with caplog.at_level(logging.WARNING, logger="axolotl"):
yield
finally:
ax_logger.propagate = old_propagate
def _xformers_available():
try:
import xformers.ops # noqa: F401
return True
except (ImportError, OSError):
return False
class TestCapabilityTables:
"""Backend capability classification.
Asserts both the static frozensets in `enums.py` and the `computed_field`
properties on a validated config read consistently from those tables, and
that user YAML cannot override the computed flags.
"""
@pytest.mark.parametrize(
"impl",
[
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
],
)
def test_supports_packing(self, impl):
assert impl in ATTN_IMPLS_SUPPORTING_PACKING
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
def test_does_not_support_packing(self, impl):
assert impl not in ATTN_IMPLS_SUPPORTING_PACKING
@pytest.mark.parametrize("impl", ["flash_attention_2", "flash_attention_3", "s2"])
def test_uses_flash_lib(self, impl):
assert impl in ATTN_IMPLS_USING_FLASH_LIB
@pytest.mark.parametrize(
"impl", ["eager", "sdpa", "xformers", "flex_attention", "sage", "fp8"]
)
def test_does_not_use_flash_lib(self, impl):
assert impl not in ATTN_IMPLS_USING_FLASH_LIB
@pytest.mark.parametrize("impl", ["eager", "sdpa"])
def test_no_dtype_cast(self, impl):
assert impl in ATTN_IMPLS_WITHOUT_DTYPE_CAST
@pytest.mark.parametrize(
"impl",
[
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
"s2",
"fp8",
],
)
def test_needs_dtype_cast(self, impl):
assert impl not in ATTN_IMPLS_WITHOUT_DTYPE_CAST
def test_known_hub_kernels_classified(self):
assert "kernels-community/flash-attn3" in ATTN_IMPLS_SUPPORTING_PACKING
assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB
assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING
def test_computed_flags_readable_on_validated_cfg(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(attn_implementation="sdpa")
validated = validate_config(cfg)
assert validated.attn_implementation == "sdpa"
assert validated.attn_supports_packing is False
assert validated.attn_uses_flash_lib is False
assert validated.attn_needs_dtype_cast is False
def test_computed_flags_not_overridable_from_yaml(self, min_base_cfg):
"""YAML attempts to override a computed field must not win."""
cfg = min_base_cfg | DictDefault(
attn_implementation="eager", attn_uses_flash_lib=True
)
validated = validate_config(cfg)
# The computed field reflects the backend, not the YAML input.
assert validated.attn_uses_flash_lib is False
class TestBackendRegistration:
"""Axolotl-owned backends register under their canonical names in HF's registries."""
@pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
def test_register_xformers(self):
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention import register_xformers_attn
register_xformers_attn()
assert "xformers" in ALL_ATTENTION_FUNCTIONS
assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS
assert (
ALL_MASK_ATTENTION_FUNCTIONS["xformers"]
== ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
)
def test_register_sage(self):
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention import register_sage_attn
register_sage_attn()
assert "sage" in ALL_ATTENTION_FUNCTIONS
assert "sage" in ALL_MASK_ATTENTION_FUNCTIONS
assert (
ALL_MASK_ATTENTION_FUNCTIONS["sage"]
== ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
)
@pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
def test_xformers_does_not_overwrite_fa2(self):
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
from axolotl.monkeypatch.attention import register_xformers_attn
register_xformers_attn()
assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2
def test_sage_does_not_overwrite_fa2(self):
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
from axolotl.monkeypatch.attention import register_sage_attn
register_sage_attn()
assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2
class TestLegacyFlagDeprecation:
"""Legacy boolean flags (flash_attention, sdp_attention, ...) map to a
canonical attn_implementation value, are stripped from the validated
config, and cannot be combined with an explicit canonical value.
"""
@staticmethod
def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data)
@pytest.mark.parametrize(
"flag,expected",
[
("flash_attention", "flash_attention_2"),
("sdp_attention", "sdpa"),
("xformers_attention", "xformers"),
("flex_attention", "flex_attention"),
("sage_attention", "sage"),
("eager_attention", "eager"),
("s2_attention", "s2"),
],
)
def test_legacy_flag_maps_to_canonical(self, flag, expected):
result = self._normalize({flag: True})
assert result["attn_implementation"] == expected
def test_legacy_flags_are_stripped_after_mapping(self):
result = self._normalize({"flash_attention": True})
for flag in [
"flash_attention",
"sdp_attention",
"xformers_attention",
"flex_attention",
"sage_attention",
"eager_attention",
"s2_attention",
]:
assert flag not in result
def test_s2_plus_flash_priority_is_s2(self):
result = self._normalize({"s2_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "s2"
def test_sage_plus_flash_priority_is_sage(self):
result = self._normalize({"sage_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "sage"
def test_canonical_plus_legacy_flag_raises(self):
with pytest.raises(ValueError, match="cannot be combined with legacy"):
self._normalize(
{"attn_implementation": "flash_attention_2", "flash_attention": True}
)
def test_canonical_plus_unrelated_legacy_flag_raises(self):
with pytest.raises(ValueError, match="cannot be combined with legacy"):
self._normalize(
{"attn_implementation": "xformers", "flash_attention": True}
)
def test_legacy_flag_stripped_on_validated_cfg(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(flash_attention=True)
validated = validate_config(cfg)
assert validated.attn_implementation == "flash_attention_2"
# Legacy flag must not survive to the validated DictDefault
# (normalizer pops it, model_dump excludes Nones).
assert "flash_attention" not in dict(validated)
def test_canonical_plus_legacy_rejected_on_full_validation(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
attn_implementation="flash_attention_2", flash_attention=True
)
with pytest.raises(ValueError, match="cannot be combined with legacy"):
validate_config(cfg)
def test_s2_plus_flash_maps_to_s2_on_full_validation(self, min_base_cfg):
"""Priority resolution applies through the full validator chain too."""
cfg = min_base_cfg | DictDefault(s2_attention=True, flash_attention=True)
validated = validate_config(cfg)
assert validated.attn_implementation == "s2"
class TestCanonicalValueAcceptance:
"""`attn_implementation` accepts canonical names and `org/name` hub-kernel
paths. Short-form aliases (`flash`, `flex`, `sdp`) and unknown bare names
are rejected. Absent input is a noop.
"""
@staticmethod
def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data)
def test_canonical_value_is_passthrough(self):
data = {"attn_implementation": "flash_attention_2"}
result = self._normalize(data)
assert result["attn_implementation"] == "flash_attention_2"
def test_hub_kernel_is_passthrough(self):
data = {"attn_implementation": "kernels-community/flash-attn3"}
result = self._normalize(data)
assert result["attn_implementation"] == "kernels-community/flash-attn3"
def test_no_attention_set_is_noop(self):
result = self._normalize({"some_other_config": True})
assert result.get("attn_implementation") is None
def test_field_validator_accepts_all_canonical(self):
for impl in CANONICAL_ATTN_IMPLS:
assert AxolotlInputConfig.validate_attn_implementation(impl) == impl
def test_field_validator_accepts_hub_kernels(self):
for impl in (
"kernels-community/flash-attn3",
"kernels-community/sage-attention",
"someorg/custom-kernel",
):
assert AxolotlInputConfig.validate_attn_implementation(impl) == impl
def test_field_validator_accepts_none(self):
assert AxolotlInputConfig.validate_attn_implementation(None) is None
@pytest.mark.parametrize("alias", ["flash", "flex", "sdp"])
def test_short_form_alias_rejected(self, alias):
with pytest.raises(ValueError, match="is not accepted"):
AxolotlInputConfig.validate_attn_implementation(alias)
def test_unknown_bare_name_rejected(self):
with pytest.raises(ValueError, match="not a recognized backend"):
AxolotlInputConfig.validate_attn_implementation("not_a_real_backend")
def test_canonical_value_passes_through_full_validation(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(attn_implementation="flash_attention_3")
validated = validate_config(cfg)
assert validated.attn_implementation == "flash_attention_3"
assert validated.attn_uses_flash_lib is True
assert validated.attn_supports_packing is True
def test_hub_kernel_passes_through_full_validation(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
attn_implementation="kernels-community/flash-attn3"
)
validated = validate_config(cfg)
assert validated.attn_implementation == "kernels-community/flash-attn3"
assert validated.attn_uses_flash_lib is True
assert validated.attn_supports_packing is True
def test_short_form_alias_rejected_on_full_validation(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(attn_implementation="flash")
with pytest.raises(ValueError, match="is not accepted"):
validate_config(cfg)
class TestGemma4HybridMode:
"""`gemma4_hybrid_attn_impl` pins `attn_implementation` to `flash_attention_2`."""
@staticmethod
def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data)
def test_defaults_to_flash_attention_2(self):
result = self._normalize({"gemma4_hybrid_attn_impl": True})
assert result["attn_implementation"] == "flash_attention_2"
def test_explicit_fa2_passes(self):
result = self._normalize(
{
"gemma4_hybrid_attn_impl": True,
"attn_implementation": "flash_attention_2",
}
)
assert result["attn_implementation"] == "flash_attention_2"
def test_non_fa2_raises(self):
"""The hybrid path requires FA2 under the hood — any other backend is
a configuration error."""
with pytest.raises(
ValueError, match="requires attn_implementation=flash_attention_2"
):
self._normalize(
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
)
class TestSamplePackingValidation:
"""`sample_packing` requires a varlen-capable backend.
Non-varlen backends (eager, sdpa) warn about cross-sample contamination;
s2 raises outright because shifted-sparse attention has no varlen path.
"""
def test_eager_warns(self, min_base_cfg, caplog):
cfg = min_base_cfg | DictDefault(
attn_implementation="eager", sample_packing=True
)
with _capture_axolotl_warnings(caplog):
validate_config(cfg)
assert any(
"does not handle cross-sample decontamination" in r.getMessage()
for r in caplog.records
)
def test_sdpa_warns(self, min_base_cfg, caplog):
cfg = min_base_cfg | DictDefault(
attn_implementation="sdpa", sample_packing=True
)
with _capture_axolotl_warnings(caplog):
validate_config(cfg)
assert any(
"does not handle cross-sample decontamination" in r.getMessage()
for r in caplog.records
)
def test_flash_attention_2_does_not_warn(self, min_base_cfg, caplog):
cfg = min_base_cfg | DictDefault(
attn_implementation="flash_attention_2", sample_packing=True
)
with _capture_axolotl_warnings(caplog):
validate_config(cfg)
assert not any(
"does not handle cross-sample decontamination" in r.getMessage()
for r in caplog.records
)
def test_s2_raises(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(attn_implementation="s2", sample_packing=True)
with pytest.raises(
ValueError, match="shifted-sparse attention does not currently support"
):
validate_config(cfg)
class TestScalingSoftmaxValidation:
"""`scaling_softmax` is only implemented under flex_attention."""
def test_non_flex_raises(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
attn_implementation="flash_attention_2", scaling_softmax=True
)
with pytest.raises(ValueError, match="scaling_softmax requires flex"):
validate_config(cfg)
def test_flex_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
attn_implementation="flex_attention", scaling_softmax=True
)
validated = validate_config(cfg)
assert validated.attn_implementation == "flex_attention"