migrate remaining consumers to canonical attn_implementation
This commit is contained in:
@@ -114,10 +114,18 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
|
||||
# This path operates on raw YAML via DictDefault (not the validated
|
||||
# AxolotlInputConfig), so we resolve flash-attn from either the canonical
|
||||
# `attn_implementation` field or the deprecated `flash_attention` boolean.
|
||||
_flash_attn_impls = {"flash_attention_2", "flash_attention_3"}
|
||||
lm_eval_flash_attention = bool(
|
||||
cfg.flash_attention or cfg.attn_implementation in _flash_attn_impls
|
||||
)
|
||||
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||
flash_attention=cfg.flash_attention,
|
||||
flash_attention=lm_eval_flash_attention,
|
||||
output_dir=cfg.output_dir,
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
|
||||
@@ -521,9 +521,9 @@ class TestMultiGPULlama:
|
||||
}
|
||||
)
|
||||
if attention_backend == "flash":
|
||||
cfg.flash_attention = True
|
||||
cfg.attn_implementation = "flash_attention_2"
|
||||
elif attention_backend == "flex":
|
||||
cfg.flex_attention = True
|
||||
cfg.attn_implementation = "flex_attention"
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
70
tests/test_no_legacy_attn_reads.py
Normal file
70
tests/test_no_legacy_attn_reads.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Guard the attn_implementation source-of-truth invariant.
|
||||
|
||||
`cfg.attn_implementation` is the single source of truth for the attention
|
||||
backend on the validated config. Legacy boolean flags (`flash_attention`,
|
||||
`sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`,
|
||||
`s2_attention`, `eager_attention`) are input-only deprecated aliases — they
|
||||
are stripped from `data` by `normalize_attn_implementation` and must never be
|
||||
read downstream.
|
||||
|
||||
This test greps `src/` and fails if it finds a `cfg.<legacy>_attention` read.
|
||||
If you're here because this test failed, migrate the read site to
|
||||
`cfg.attn_implementation` or one of the `attn_supports_packing /
|
||||
attn_uses_flash_lib / attn_needs_dtype_cast` computed capability flags.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
LEGACY_FLAGS = (
|
||||
"flash_attention",
|
||||
"sdp_attention",
|
||||
"xformers_attention",
|
||||
"flex_attention",
|
||||
"sage_attention",
|
||||
"s2_attention",
|
||||
"eager_attention",
|
||||
)
|
||||
|
||||
# The normalizer is allowed to read the legacy keys (that's its job).
|
||||
# lm_eval/cli.py is a raw-YAML entry point (bypasses AxolotlInputConfig) that
|
||||
# honors both forms during the deprecation period — when we remove the legacy
|
||||
# flags entirely, drop this allowlist entry and the BC branch in that file.
|
||||
ALLOWED_FILES = {
|
||||
Path("src/axolotl/utils/schemas/config.py"),
|
||||
Path("src/axolotl/integrations/lm_eval/cli.py"),
|
||||
}
|
||||
|
||||
# `cfg.<flag>`, `self.cfg.<flag>`, `data.get("<flag>")`, `data["<flag>"]`
|
||||
_PATTERNS = [re.compile(rf"\bcfg\.{flag}\b") for flag in LEGACY_FLAGS] + [
|
||||
re.compile(rf'\bdata\.get\("{flag}"\)') for flag in LEGACY_FLAGS
|
||||
]
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
return Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def test_no_legacy_attn_reads_in_src():
|
||||
root = _repo_root()
|
||||
src = root / "src"
|
||||
offenders: list[str] = []
|
||||
|
||||
for py_file in src.rglob("*.py"):
|
||||
rel = py_file.relative_to(root)
|
||||
if rel in ALLOWED_FILES:
|
||||
continue
|
||||
text = py_file.read_text(encoding="utf-8")
|
||||
for pattern in _PATTERNS:
|
||||
for match in pattern.finditer(text):
|
||||
# Line number for the user's convenience.
|
||||
line_no = text.count("\n", 0, match.start()) + 1
|
||||
offenders.append(f"{rel}:{line_no} {match.group(0)}")
|
||||
|
||||
assert not offenders, (
|
||||
"Found legacy attention-flag reads in src/. Migrate to "
|
||||
"`cfg.attn_implementation` / capability flags:\n "
|
||||
+ "\n ".join(sorted(offenders))
|
||||
)
|
||||
Reference in New Issue
Block a user