From a0d24bcc192ed9e948a4089cea6e31b1138495bd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 23 Apr 2026 21:26:18 +0000 Subject: [PATCH] migrate remaining consumers to canonical attn_implementation --- src/axolotl/integrations/lm_eval/cli.py | 10 +++- tests/e2e/multigpu/test_llama.py | 4 +- tests/test_no_legacy_attn_reads.py | 70 +++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 3 deletions(-) create mode 100644 tests/test_no_legacy_attn_reads.py diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 4b905d476..a20f4d154 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -114,10 +114,18 @@ def lm_eval(config: str, cloud: Optional[str] = None): with open(config, encoding="utf-8") as file: cfg: DictDefault = DictDefault(yaml.safe_load(file)) + # This path operates on raw YAML via DictDefault (not the validated + # AxolotlInputConfig), so we resolve flash-attn from either the canonical + # `attn_implementation` field or the deprecated `flash_attention` boolean. 
+ _flash_attn_impls = {"flash_attention_2", "flash_attention_3"} + lm_eval_flash_attention = bool( + cfg.flash_attention or cfg.attn_implementation in _flash_attn_impls + ) + for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, bfloat16=cfg.bfloat16 or cfg.bf16, - flash_attention=cfg.flash_attention, + flash_attention=lm_eval_flash_attention, output_dir=cfg.output_dir, batch_size=cfg.lm_eval_batch_size, wandb_project=cfg.wandb_project, diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 1e3757dcf..b89c93522 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -521,9 +521,9 @@ class TestMultiGPULlama: } ) if attention_backend == "flash": - cfg.flash_attention = True + cfg.attn_implementation = "flash_attention_2" elif attention_backend == "flex": - cfg.flex_attention = True + cfg.attn_implementation = "flex_attention" # write cfg to yaml file Path(temp_dir).mkdir(parents=True, exist_ok=True) diff --git a/tests/test_no_legacy_attn_reads.py b/tests/test_no_legacy_attn_reads.py new file mode 100644 index 000000000..ebeaac354 --- /dev/null +++ b/tests/test_no_legacy_attn_reads.py @@ -0,0 +1,70 @@ +"""Guard the attn_implementation source-of-truth invariant. + +`cfg.attn_implementation` is the single source of truth for the attention +backend on the validated config. Legacy boolean flags (`flash_attention`, +`sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, +`s2_attention`, `eager_attention`) are input-only deprecated aliases — they +are stripped from `data` by `normalize_attn_implementation` and must never be +read downstream. + +This test greps `src/` and fails if it finds a `cfg.*_attention` read. +If you're here because this test failed, migrate the read site to +`cfg.attn_implementation` or one of the `attn_supports_packing / +attn_uses_flash_lib / attn_needs_dtype_cast` computed capability flags.
from __future__ import annotations

import re
from pathlib import Path

# Deprecated boolean aliases for the attention backend. Each entry is matched
# literally by the patterns below.
LEGACY_FLAGS = (
    "flash_attention",
    "sdp_attention",
    "xformers_attention",
    "flex_attention",
    "sage_attention",
    "s2_attention",
    "eager_attention",
)

# The normalizer is allowed to read the legacy keys (that's its job).
# lm_eval/cli.py is a raw-YAML entry point (bypasses AxolotlInputConfig) that
# honors both forms during the deprecation period — when we remove the legacy
# flags entirely, drop this allowlist entry and the BC branch in that file.
ALLOWED_FILES = {
    Path("src/axolotl/utils/schemas/config.py"),
    Path("src/axolotl/integrations/lm_eval/cli.py"),
}

# Regex fragments shared by the patterns: any legacy flag name, and that name
# wrapped in either quote style.
_FLAG_ALTERNATION = "|".join(re.escape(flag) for flag in LEGACY_FLAGS)
_QUOTE = "[\"']"
_QUOTED_FLAG = rf"{_QUOTE}(?:{_FLAG_ALTERNATION}){_QUOTE}"

# Forms that are flagged (`{flag}` stands for any entry of LEGACY_FLAGS):
#   `cfg.{flag}` and `self.cfg.{flag}`   (the `\b` before `cfg` covers both)
#   `data.get("{flag}")` and `data.get("{flag}", default)`
#   `data["{flag}"]`
# The trailing `\b` on the attribute form keeps `cfg.flash_attention_2` and
# similar longer names from matching.
_PATTERNS = [
    re.compile(rf"\bcfg\.(?:{_FLAG_ALTERNATION})\b"),
    re.compile(rf"\bdata\.get\(\s*{_QUOTED_FLAG}\s*[,)]"),
    re.compile(rf"\bdata\[\s*{_QUOTED_FLAG}\s*\]"),
]


def _repo_root() -> Path:
    """Return the directory two levels above this file (the repo root)."""
    return Path(__file__).resolve().parent.parent


def _find_legacy_reads(text: str) -> list[tuple[int, str]]:
    """Return sorted ``(line_number, matched_text)`` pairs for *text*.

    Line numbers are 1-based. An empty list means no legacy attention-flag
    access was found.
    """
    hits = []
    for pattern in _PATTERNS:
        for match in pattern.finditer(text):
            # Count the newlines before the match to get its line number.
            line_no = text.count("\n", 0, match.start()) + 1
            hits.append((line_no, match.group(0)))
    return sorted(hits)


def test_no_legacy_attn_reads_in_src():
    """Fail if any non-allowlisted file under ``src/`` touches a legacy flag."""
    root = _repo_root()
    offenders: list[str] = []

    for py_file in (root / "src").rglob("*.py"):
        rel = py_file.relative_to(root)
        if rel in ALLOWED_FILES:
            continue
        text = py_file.read_text(encoding="utf-8")
        offenders.extend(
            f"{rel}:{line_no} {snippet}"
            for line_no, snippet in _find_legacy_reads(text)
        )

    assert not offenders, (
        "Found legacy attention-flag reads in src/. Migrate to "
        "`cfg.attn_implementation` / capability flags:\n "
        + "\n ".join(sorted(offenders))
    )