"""Enforce attn_implementation as the single source of truth. Fails if src/ contains a cfg._attention read. Migrate offending sites to cfg.attn_implementation or the attn_supports_packing/attn_uses_flash_lib/ attn_needs_dtype_cast computed flags. """ from __future__ import annotations import re from pathlib import Path LEGACY_FLAGS = ( "flash_attention", "sdp_attention", "xformers_attention", "flex_attention", "sage_attention", "s2_attention", "eager_attention", ) # The normalizer is allowed to read the legacy keys (that's its job). # lm_eval/cli.py is a raw-YAML entry point (bypasses AxolotlInputConfig) that # honors both forms during the deprecation period — when we remove the legacy # flags entirely, drop this allowlist entry and the BC branch in that file. ALLOWED_FILES = { Path("src/axolotl/utils/schemas/config.py"), Path("src/axolotl/integrations/lm_eval/cli.py"), } # `cfg.`, `self.cfg.`, `data.get("")`, `data[""]` _PATTERNS = [re.compile(rf"\bcfg\.{flag}\b") for flag in LEGACY_FLAGS] + [ re.compile(rf'\bdata\.get\("{flag}"\)') for flag in LEGACY_FLAGS ] def _repo_root() -> Path: return Path(__file__).resolve().parent.parent def test_no_legacy_attn_reads_in_src(): root = _repo_root() src = root / "src" offenders: list[str] = [] for py_file in src.rglob("*.py"): rel = py_file.relative_to(root) if rel in ALLOWED_FILES: continue text = py_file.read_text(encoding="utf-8") for pattern in _PATTERNS: for match in pattern.finditer(text): # Line number for the user's convenience. line_no = text.count("\n", 0, match.start()) + 1 offenders.append(f"{rel}:{line_no} {match.group(0)}") assert not offenders, ( "Found legacy attention-flag reads in src/. Migrate to " "`cfg.attn_implementation` / capability flags:\n " + "\n ".join(sorted(offenders)) )