From 2579c496d55feab35d95bdb5d8e75f6cd9bbca3c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 23 Apr 2026 21:17:10 +0000 Subject: [PATCH] make attn_implementation the single source of truth --- ATTN_REFACTOR_REVIEW.md | 142 ++++++++ src/axolotl/core/builders/causal.py | 4 +- src/axolotl/integrations/lm_eval/__init__.py | 2 +- src/axolotl/loaders/model.py | 38 +- src/axolotl/loaders/patch_manager.py | 2 +- src/axolotl/loaders/tokenizer.py | 2 +- src/axolotl/utils/schemas/config.py | 225 ++++++------ src/axolotl/utils/schemas/enums.py | 83 ++++- src/axolotl/utils/schemas/validation.py | 32 +- tests/test_attn_implementation.py | 348 ++++++++----------- 10 files changed, 491 insertions(+), 387 deletions(-) create mode 100644 ATTN_REFACTOR_REVIEW.md diff --git a/ATTN_REFACTOR_REVIEW.md b/ATTN_REFACTOR_REVIEW.md new file mode 100644 index 000000000..84b3b6a0b --- /dev/null +++ b/ATTN_REFACTOR_REVIEW.md @@ -0,0 +1,142 @@ +# `attn-implementation-refactor` branch review + +Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`). +Scope: 16 files, +682 / −96. + +## 1. What the branch is trying to do + +Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains. + +Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`. + +Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`. + +## 2. Target design + +**`cfg.attn_implementation` is the single source of truth on the validated config.** + +- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases: + - HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.) + - Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`. + - Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables. + Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name). +- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed. +- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. 
Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`). +- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string. +- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist. +- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`. + +This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet. + +## 3. Gaps and holes + +### A. Legacy flags are still parallel state, not input-only + +1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G). +2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time. +3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only legacy-only path warns. +4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class. + +### B. Validators that still only check the legacy flag + +5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check. +6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned. + +### C. Non-enum strings fall through the capability tables + +7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`. +8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels). +9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as conflict. Common combo in existing YAMLs breaks hard on upgrade. + +### D. 
Silent behaviour change for xformers + +10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission. + +### E. Capability flags are writable Pydantic fields, not computed + +11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer. + +### F. Validator ordering (not covered by tests) + +12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way. + +### G. Remaining legacy reads in `src/` + +13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only. +14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern. +15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them. + +### H. fp8 operational risk + +16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match. + +### I. Test coverage gaps + +17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**: + - Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested). + - Verify capability flags can't be overridden from YAML (item 11). + - Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6). + - Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5). + - Cover hub-kernel + legacy flag combo (item 9). + - Cover `flash_attention_2` canonicalization (item 7). + +## 4. Fix plan + +Four phases, each commit-sized. Phases 1–2 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs. + +### Phase 1 — Lock down the data model + +1. Drop the `AttnImplementation` enum. 
`attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name. +2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed: + - `flash_attention → flash_attention_2` + - `sdp_attention → sdpa` + - `flex_attention → flex_attention` + - `xformers_attention → xformers` + - `sage_attention → sage` + - `s2_attention → s2` + - `eager_attention → eager` + Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence). + + **Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection. + + **Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1). +3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels): + - Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`. + - Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`. + - No-dtype-cast: `{eager, sdpa}`. + Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`). +4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF. +5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation. + +### Phase 2 — Fix the validators + +6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend. +7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`. +8. Sweep items (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` dual-check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs. +9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`. 
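+
+A minimal sketch of what that fp8 preflight could look like, mirroring the shape of `check_flex_torch_version` on `AxolotlConfigWCapabilities`. The `torchao_version` and `compute_capability` keys under `env_capabilities` (and their formats), plus the top-level `use_cache` lookup, are assumptions; only `torch_version` is known to be read from `env_capabilities` today.
+
+```python
+import logging
+
+from packaging import version
+from pydantic import model_validator
+
+LOG = logging.getLogger(__name__)
+
+
+class FP8PreflightMixin:
+    """Sketch only; fold into AxolotlConfigWCapabilities' validators."""
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_fp8_attention_requirements(cls, data):
+        if data.get("attn_implementation") != "fp8":
+            return data
+
+        env = data.get("env_capabilities", {}) or {}
+
+        torch_version = env.get("torch_version")
+        if torch_version and version.parse(str(torch_version)) < version.parse("2.11"):
+            raise ValueError("attn_implementation: fp8 requires torch >= 2.11")
+
+        torchao_version = env.get("torchao_version")  # assumed key
+        if torchao_version and version.parse(str(torchao_version)) < version.parse("0.17.0"):
+            raise ValueError("attn_implementation: fp8 requires torchao >= 0.17.0")
+
+        # Assumes "sm_90"-style strings; adjust if the probe reports "9.0"-style values.
+        cc = str(env.get("compute_capability", "")).removeprefix("sm_")
+        if cc.isdigit() and int(cc) < 90:
+            raise ValueError("attn_implementation: fp8 requires SM90 or newer")
+
+        if data.get("use_cache") is not False:  # assumed location of the use_cache override
+            LOG.warning(
+                "attn_implementation: fp8 expects `use_cache: false`; "
+                "torchao disables the KV cache at runtime."
+            )
+        return data
+```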
+
+### Phase 3 — Migrate remaining consumers
+
+10. `lm_eval/cli.py:120` → `flash_attention=cfg.attn_uses_flash_lib`.
+11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates to `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
+12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
+13. `tests/e2e/multigpu/test_llama.py:524-526` → `cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
+14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
+    - Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keep the patches live).
+    - Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
+    Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
+15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting (sketch in the appendix below).
+
+### Phase 4 — Tests + docs
+
+16. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
+17. Add one test per gap closed above: `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
+18. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
+19. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (the normalizer maps them with a deprecation warning), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
+
+## 5. Ordering & risk
+
+- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 14 is the explicit decision point).
+- Phase 2 is independent of phase 1 and can land first as a quick safety net.
+- Phase 3 step 14 is the only judgment call — flag for review before choosing.
+- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.
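+
+## 6. Appendix: legacy-read guard sketch (phase 3, step 15)
+
+A grep-style pytest along these lines keeps the invariant from rotting once phase 3 lands. The file location (`tests/test_no_legacy_attn_reads.py`), the allowlist, and the exact pattern are assumptions to tune; the point is a cheap check that fails loudly when a new `cfg.flash_attention`-style read appears.
+
+```python
+"""Fail if anything outside the schema normalizer reads the legacy attention flags."""
+
+import re
+from pathlib import Path
+
+LEGACY_ATTN_ATTRS = (
+    "flash_attention",
+    "sdp_attention",
+    "xformers_attention",
+    "flex_attention",
+    "sage_attention",
+    "s2_attention",
+    "eager_attention",
+)
+# Matches attribute reads like `cfg.flash_attention` / `self.cfg.sdp_attention`.
+PATTERN = re.compile(r"\bcfg\.(" + "|".join(LEGACY_ATTN_ATTRS) + r")\b")
+
+# Only the normalizer in the schema module may touch the legacy names.
+ALLOWED = {Path("src/axolotl/utils/schemas/config.py")}
+
+REPO_ROOT = Path(__file__).parent.parent  # assumes this file lives in tests/ at the repo root
+
+
+def test_no_legacy_attention_flag_reads():
+    offenders = []
+    for path in (REPO_ROOT / "src").rglob("*.py"):
+        rel = path.relative_to(REPO_ROOT)
+        if rel in ALLOWED:
+            continue
+        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
+            if PATTERN.search(line):
+                offenders.append(f"{rel}:{lineno}: {line.strip()}")
+    assert not offenders, (
+        "Read cfg.attn_implementation or the capability flags instead of the "
+        "legacy attention booleans:\n" + "\n".join(offenders)
+    )
+```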
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index d0b298bee..15624173d 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -502,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # supported multipack models, or non-flash-attention llama if ( - self.cfg.attn_implementation == "flex" + self.cfg.attn_implementation == "flex_attention" or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or ( self.cfg.model_config_type in ["llama"] - and self.cfg.attn_implementation != "flash" + and self.cfg.attn_implementation != "flash_attention_2" ) ): collator = V2BatchSamplerDataCollatorForSeq2Seq diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index 53386956a..732ce1592 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin): for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, bfloat16=cfg.bfloat16 or cfg.bf16, - flash_attention=(cfg.attn_implementation == "flash"), + flash_attention=cfg.attn_uses_flash_lib, output_dir=cfg.output_dir, batch_size=cfg.lm_eval_batch_size, wandb_project=cfg.wandb_project, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 0847c9b79..997e0739d 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -628,33 +628,25 @@ class ModelLoader: ) def _set_attention_config(self): - """Sample packing uses custom FA2 patch""" - # Map attn_implementation enum values to HF attn_implementation strings. - # xformers/sage are registered in ALL_ATTENTION_FUNCTIONS and - # ALL_MASK_ATTENTION_FUNCTIONS under their own names with FA2 mask - # behavior, so they no longer need to masquerade as flash_attention_2. - # s2 still uses flash_attention_2 because it modifies FA2 internals. - # Hub kernel strings (e.g. "kernels-community/flash-attn3") fall - # through the .get() and are passed to HF unchanged. - _ATTN_IMPL_TO_HF = { - "eager": "eager", - "flash": "flash_attention_2", - "sdpa": "sdpa", - "xformers": "xformers", - "flex": "flex_attention", - "sage": "sage", - "s2": "flash_attention_2", - "fp8": "sdpa", - } + # s2 and fp8 need a different HF backend at load time than their + # canonical name: s2 patches FA2 internals, so load under FA2; fp8 + # replaces F.scaled_dot_product_attention post-load, so load under sdpa. + # Every other canonical name (and hub-kernel paths) is passed through + # verbatim — xformers/sage/flash_attention_* are registered under their + # own names in ALL_ATTENTION_FUNCTIONS before model load. + _LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"} if self.cfg.gemma4_hybrid_attn_impl: - # Load model with flash_attention_2 for sliding window layers; - # global layers will be patched to sdpa post-load. - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = "flash_attention_2" + # Load with flash_attention_2 for sliding-window layers; global + # layers are swapped to sdpa post-load. 
+ hf_impl = "flash_attention_2" elif self.cfg.attn_implementation: - hf_impl = _ATTN_IMPL_TO_HF.get( + hf_impl = _LOAD_TIME_OVERRIDE.get( self.cfg.attn_implementation, self.cfg.attn_implementation ) + else: + hf_impl = None + + if hf_impl is not None: self.model_kwargs["attn_implementation"] = hf_impl self.model_config._attn_implementation = hf_impl diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 5783c5996..68952014f 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -333,7 +333,7 @@ class PatchManager: def _apply_flex_attention_patches(self): """Apply patches for flexible attention.""" - if self.cfg.attn_implementation == "flex": + if self.cfg.attn_implementation == "flex_attention": from axolotl.monkeypatch.attention.flex_attn import ( patch_flex_wrapper, ) diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 48f4a9fa3..572a880bd 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -207,7 +207,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: # Mistral's official FA implementation requires left padding if ( cfg.is_mistral_derived_model - and cfg.attn_implementation == "flash" + and cfg.attn_implementation == "flash_attention_2" and not cfg.sample_packing ): tokenizer.padding_side = "left" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 260cb2169..fb9cf3bde 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -10,7 +10,9 @@ from pydantic import ( BaseModel, Field, StringConstraints, + computed_field, field_serializer, + field_validator, model_validator, ) @@ -28,10 +30,12 @@ from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.enums import ( - _NO_DTYPE_CAST_ATTN_IMPLS, - _NON_PACKING_ATTN_IMPLS, - FLASH_ATTN_LIB_IMPLS, - AttnImplementation, + ATTN_IMPLS_SUPPORTING_PACKING, + ATTN_IMPLS_USING_FLASH_LIB, + ATTN_IMPLS_WITHOUT_DTYPE_CAST, + CANONICAL_ATTN_IMPLS, + LEGACY_ATTN_FLAG_TO_IMPL, + SHORT_FORM_ALIAS_TO_CANONICAL, ChatTemplate, RingAttnFunc, RLType, @@ -739,28 +743,35 @@ class AxolotlInputConfig( xformers_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: xformers` instead.", json_schema_extra={ - "description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers" + "description": "[DEPRECATED] Use `attn_implementation: xformers`. https://github.com/facebookresearch/xformers" }, ) sdp_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: sdpa` instead.", json_schema_extra={ - "description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" + "description": "[DEPRECATED] Use `attn_implementation: sdpa`." }, ) s2_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: s2` instead.", json_schema_extra={ - "description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" + "description": "[DEPRECATED] Use `attn_implementation: s2`. 
Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" }, ) - flex_attention: bool | None = None + flex_attention: bool | None = Field( + default=None, + deprecated="Use `attn_implementation: flex_attention` instead.", + ) flex_attn_compile_kwargs: dict[str, Any] | None = None flash_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: flash_attention_2` instead.", json_schema_extra={ - "description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention" + "description": "[DEPRECATED] Use `attn_implementation: flash_attention_2`. https://github.com/Dao-AILab/flash-attention" }, ) flash_attn_cross_entropy: bool | None = Field( @@ -787,17 +798,26 @@ class AxolotlInputConfig( ) sage_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: sage` instead.", json_schema_extra={ - "description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention" + "description": "[DEPRECATED] Use `attn_implementation: sage`. https://github.com/thu-ml/SageAttention" }, ) - eager_attention: bool | None = None + eager_attention: bool | None = Field( + default=None, + deprecated="Use `attn_implementation: eager` instead.", + ) - attn_implementation: AttnImplementation | str | None = Field( + attn_implementation: str | None = Field( default=None, json_schema_extra={ - "description": "Attention backend: eager, flash, sdpa, xformers, flex, sage, s2, fp8, or a custom string for kernels." + "description": ( + "Attention backend. Canonical values: eager, sdpa, flash_attention_2, " + "flash_attention_3, flex_attention, xformers, sage, s2, fp8. Hub-kernel " + "paths (e.g. kernels-community/flash-attn3) are also accepted and passed " + "through to transformers." + ) }, ) @@ -1335,29 +1355,24 @@ class AxolotlInputConfig( return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return None - # --- Attention capability flags (computed by normalize_attn_implementation) --- + # --- Attention capability flags (derived from attn_implementation) --- - attn_supports_packing: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Whether the attention backend supports varlen sample packing. " - "Computed automatically from attn_implementation." - }, - ) - attn_uses_flash_lib: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Whether the attention backend requires axolotl's flash_attn " - "monkeypatches. Computed automatically from attn_implementation." - }, - ) - attn_needs_dtype_cast: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Whether the attention backend needs embedding dtype cast to " - "fp16/bf16. Computed automatically from attn_implementation." 
- }, - ) + @computed_field # type: ignore[misc] + @property + def attn_supports_packing(self) -> bool: + return self.attn_implementation in ATTN_IMPLS_SUPPORTING_PACKING + + @computed_field # type: ignore[misc] + @property + def attn_uses_flash_lib(self) -> bool: + return self.attn_implementation in ATTN_IMPLS_USING_FLASH_LIB + + @computed_field # type: ignore[misc] + @property + def attn_needs_dtype_cast(self) -> bool: + if self.attn_implementation is None: + return False + return self.attn_implementation not in ATTN_IMPLS_WITHOUT_DTYPE_CAST @model_validator(mode="before") @classmethod @@ -1382,90 +1397,83 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def normalize_attn_implementation(cls, data): - """Normalize attention config: map between attn_implementation enum and legacy boolean flags.""" + """Map legacy boolean attention flags to the canonical `attn_implementation`. + + `attn_implementation` is the single source of truth on the validated + config. Legacy booleans (`flash_attention: true`, …) are input-only + aliases; this validator warns, maps them to their canonical value, and + strips them from `data` so they cannot be read downstream. + + Raises if a canonical `attn_implementation` is set alongside any legacy + boolean — users must pick one. + """ + if not isinstance(data, dict): + return data + attn_impl = data.get("attn_implementation") + set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)] - # If gemma4_hybrid_attn_impl is set but no attn_implementation, default - # to flash (the sliding-window layers use FA2, and packing should be enabled). - if data.get("gemma4_hybrid_attn_impl") and not attn_impl: - data["attn_implementation"] = "flash" - attn_impl = "flash" - - # Mapping: attn_implementation value -> primary legacy flag to set - impl_to_flag = { - "eager": "eager_attention", - "flash": "flash_attention", - "sdpa": "sdp_attention", - "xformers": "xformers_attention", - "flex": "flex_attention", - "sage": "sage_attention", - "s2": "s2_attention", - "fp8": None, # new, no legacy flag - } - - # Reverse mapping: legacy flag -> attn_implementation value - flag_to_impl = { - "eager_attention": "eager", - "flash_attention": "flash", - "sdp_attention": "sdpa", - "xformers_attention": "xformers", - "flex_attention": "flex", - "sage_attention": "sage", - "s2_attention": "s2", - } - - # Find which legacy flags are set - set_flags = [f for f, impl in flag_to_impl.items() if data.get(f)] + # gemma4_hybrid defaults to flash_attention_2 when user didn't pick a + # backend. The sliding-window layers run under FA2; post-load patching + # swaps global layers to sdpa (see `_apply_gemma_hybrid_attention`). + if data.get("gemma4_hybrid_attn_impl") and not attn_impl and not set_flags: + data["attn_implementation"] = "flash_attention_2" + attn_impl = "flash_attention_2" if attn_impl and set_flags: - # Both set — check consistency - expected_flag = impl_to_flag.get(attn_impl) - for flag in set_flags: - if flag != expected_flag: - raise ValueError( - f"attn_implementation={attn_impl!r} conflicts with {flag}=true. " - f"Use only attn_implementation or the legacy flag, not both." 
- ) - elif attn_impl and not set_flags: - # attn_implementation set, no legacy flags — set primary for backwards compat - flag = impl_to_flag.get(attn_impl) - if flag: - data[flag] = True - elif not attn_impl and set_flags: - # Legacy flags set, no attn_implementation — map to enum, warn - # Priority: specific backends first, then generic flash/sdp/eager - priority = [ - "xformers_attention", - "s2_attention", - "sage_attention", - "flex_attention", - "flash_attention", - "sdp_attention", - "eager_attention", - ] - for flag in priority: + raise ValueError( + f"attn_implementation={attn_impl!r} cannot be combined with legacy " + f"attention flags ({', '.join(sorted(set_flags))}). The legacy " + f"flags are deprecated — set only `attn_implementation`." + ) + + if not attn_impl and set_flags: + # Priority: specific backends beat generic flash/sdp/eager fallbacks. + for flag in LEGACY_ATTN_FLAG_TO_IMPL: if flag in set_flags: - data["attn_implementation"] = flag_to_impl[flag] + canonical = LEGACY_ATTN_FLAG_TO_IMPL[flag] + data["attn_implementation"] = canonical LOG.warning( - "`%s: true` is deprecated. Use `attn_implementation: %s` instead.", + "`%s: true` is deprecated and will be removed in a future " + "release. Use `attn_implementation: %s` instead.", flag, - flag_to_impl[flag], + canonical, ) break - # Compute capability flags from the final attn_implementation value - impl = data.get("attn_implementation") - if impl: - data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS - data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS - data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS - else: - data["attn_supports_packing"] = False - data["attn_uses_flash_lib"] = False - data["attn_needs_dtype_cast"] = False + # Strip legacy flags from validated data — canonical field is authoritative. + for flag in LEGACY_ATTN_FLAG_TO_IMPL: + data.pop(flag, None) return data + @field_validator("attn_implementation", mode="before") + @classmethod + def validate_attn_implementation(cls, value): + """Accept canonical names and hub-kernel paths; reject short-form aliases.""" + if value is None: + return None + if not isinstance(value, str): + raise TypeError( + f"attn_implementation must be a string, got {type(value).__name__}" + ) + if value in CANONICAL_ATTN_IMPLS: + return value + if "/" in value: + # Hub-kernel path, e.g. "kernels-community/flash-attn3". Pass through. + return value + if value in SHORT_FORM_ALIAS_TO_CANONICAL: + canonical = SHORT_FORM_ALIAS_TO_CANONICAL[value] + raise ValueError( + f"attn_implementation={value!r} is not accepted. " + f"Use the canonical name {canonical!r} instead." + ) + raise ValueError( + f"attn_implementation={value!r} is not a recognized backend. " + f"Expected one of: {sorted(CANONICAL_ATTN_IMPLS)}, or a hub-kernel " + f"path containing '/'." 
+ ) + @model_validator(mode="before") @classmethod def check_sageattn_wo_sample_packing(cls, data): @@ -1763,7 +1771,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): @model_validator(mode="before") @classmethod def check_flex_torch_version(cls, data): - if data.get("flex_attention") or data.get("attn_implementation") == "flex": + if ( + data.get("flex_attention") + or data.get("attn_implementation") == "flex_attention" + ): env_capabilities = data.get("env_capabilities", {}) torch_version = env_capabilities.get("torch_version") diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index f01d4bd7a..edeba6412 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -97,30 +97,75 @@ class CustomSupportedOptimizers(str, Enum): flash_lion = "flash_lion" -class AttnImplementation(str, Enum): - """Attention backend implementations""" +# Canonical values accepted for `attn_implementation`. These are passed to HF +# verbatim via `model.config._attn_implementation`. HF-native backends use HF's +# own names (`flash_attention_2`, `flex_attention`, ...); axolotl-owned backends +# (`xformers`, `sage`, `s2`, `fp8`) register into `ALL_ATTENTION_FUNCTIONS` under +# these exact names. Hub-kernel paths (e.g. `kernels-community/flash-attn3`) are +# not in this set — they pass through the validator via the "/" check. +CANONICAL_ATTN_IMPLS = frozenset( + { + "eager", + "sdpa", + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + "s2", + "fp8", + } +) - eager = "eager" # pylint: disable=invalid-name - flash = "flash" # pylint: disable=invalid-name - sdpa = "sdpa" # pylint: disable=invalid-name - xformers = "xformers" # pylint: disable=invalid-name - flex = "flex" # pylint: disable=invalid-name - sage = "sage" # pylint: disable=invalid-name - s2 = "s2" # pylint: disable=invalid-name - fp8 = "fp8" # pylint: disable=invalid-name +# Legacy boolean attention flags → canonical `attn_implementation`. Kept for +# backwards compatibility; the normalizer warns and strips these from the +# validated config. Priority order (first match wins) matches the old priority: +# specific backends beat the generic flash/sdp/eager fallbacks. +LEGACY_ATTN_FLAG_TO_IMPL = { + "xformers_attention": "xformers", + "s2_attention": "s2", + "sage_attention": "sage", + "flex_attention": "flex_attention", + "flash_attention": "flash_attention_2", + "sdp_attention": "sdpa", + "eager_attention": "eager", +} +# Short-form aliases that were accepted by the in-progress branch but are +# rejected going forward. Mapped to canonical names only to produce a helpful +# error message pointing users at the right value. +SHORT_FORM_ALIAS_TO_CANONICAL = { + "flash": "flash_attention_2", + "flex": "flex_attention", + "sdp": "sdpa", +} -# Backends that require the flash_attn library (Dao-AILab/flash-attention) -# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.) -FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"}) +# Backends that support varlen sample packing via `position_ids`. +ATTN_IMPLS_SUPPORTING_PACKING = frozenset( + { + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + "kernels-community/flash-attn3", + "kernels-community/sage-attention", + } +) -# Known backends that do NOT support varlen sample packing via position_ids. -# Used as an exclusion list: unknown strings (e.g., HF hub kernels like -# "kernels-community/flash-attn3") default to packing-capable. 
-_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"}) +# Backends that require the flash_attn library (Dao-AILab/flash-attention) for +# axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, ring-FA, ...). +ATTN_IMPLS_USING_FLASH_LIB = frozenset( + { + "flash_attention_2", + "flash_attention_3", + "s2", + "kernels-community/flash-attn3", + } +) -# Known backends that do NOT need embedding dtype cast. -_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"}) +# Backends for which embeddings stay in fp32. Everything else needs fp16/bf16. +ATTN_IMPLS_WITHOUT_DTYPE_CAST = frozenset({"eager", "sdpa"}) class RingAttnFunc(str, Enum): diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 15ab26bfa..6f57da971 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -13,7 +13,7 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import ( - _NON_PACKING_ATTN_IMPLS, + ATTN_IMPLS_SUPPORTING_PACKING, ChatTemplate, RingAttnFunc, RLType, @@ -184,26 +184,8 @@ class DatasetValidationMixin: class AttentionValidationMixin: """Validation methods related to attention mechanisms.""" - @model_validator(mode="before") - @classmethod - def check_attention_fields(cls, data): - # If attn_implementation is set, the enum handles mutual exclusivity. - # This validator catches legacy configs with multiple boolean flags. - if data.get("attn_implementation"): - return data - fields = ( - "xformers_attention", - "sdp_attention", - # "s2_attention", # requires both FA and this to be enabled - "flash_attention", - "flex_attention", - "sage_attention", - ) - non_empty_count = sum(1 for field in fields if data.get(field)) - - if non_empty_count > 1: - raise ValueError(f"Only one of {', '.join(fields)} must be set") - return data + # `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation` + # is now the single entry point for attention-input mapping and conflict detection. 
@model_validator(mode="before") @classmethod @@ -238,7 +220,8 @@ class AttentionValidationMixin: @classmethod def check_scaling_softmax_requires_flex(cls, data): if data.get("scaling_softmax") and not ( - data.get("flex_attention") or data.get("attn_implementation") == "flex" + data.get("flex_attention") + or data.get("attn_implementation") == "flex_attention" ): raise ValueError( "scaling_softmax requires flex attention.\n" @@ -956,7 +939,7 @@ class OptimizationValidationMixin: if data.get("batch_flattening"): batch_flattening_auto = data.get("batch_flattening") == "auto" has_varlen_attn = ( - data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS + data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING if data.get("attn_implementation") else data.get("flash_attention") ) @@ -1683,7 +1666,8 @@ class EBFTValidationMixin: data.get("rl") == "ebft" and data.get("ebft", {}).get("mode") == "strided" and ( - data.get("flex_attention") or data.get("attn_implementation") == "flex" + data.get("flex_attention") + or data.get("attn_implementation") == "flex_attention" ) and data.get("gradient_checkpointing") ): diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py index 1973d5f74..68282a59b 100644 --- a/tests/test_attn_implementation.py +++ b/tests/test_attn_implementation.py @@ -1,272 +1,205 @@ -""" -Tests for attn_implementation normalization, registry registration, -capability properties, and backwards compatibility with legacy boolean -attention flags. +"""Tests for attn_implementation normalization and capability computation. + +Covers the Phase 1 contract: +- `attn_implementation` accepts canonical names only; short-form aliases are rejected. +- Legacy boolean flags are mapped to the canonical value, warned on, and stripped. +- Canonical `attn_implementation` + legacy flag raises. +- Capability flags are computed from `attn_implementation`. 
""" import pytest from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.enums import ( - _NO_DTYPE_CAST_ATTN_IMPLS, - _NON_PACKING_ATTN_IMPLS, - FLASH_ATTN_LIB_IMPLS, + ATTN_IMPLS_SUPPORTING_PACKING, + ATTN_IMPLS_USING_FLASH_LIB, + ATTN_IMPLS_WITHOUT_DTYPE_CAST, + CANONICAL_ATTN_IMPLS, ) -class TestAttnImplementationNormalizer: - """Test the normalize_attn_implementation validator.""" +class TestNormalizerLegacyMapping: + """Legacy boolean flags map to canonical attn_implementation.""" @staticmethod def _normalize(data): return AxolotlInputConfig.normalize_attn_implementation(data) - # --- Forward mapping: attn_implementation -> legacy flags --- - @pytest.mark.parametrize( - "impl,expected_flag", + "flag,expected", [ - ("eager", "eager_attention"), - ("flash", "flash_attention"), - ("sdpa", "sdp_attention"), - ("flex", "flex_attention"), - ("xformers", "xformers_attention"), - ("sage", "sage_attention"), - ("s2", "s2_attention"), - ], - ) - def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag): - data = {"attn_implementation": impl} - result = AxolotlInputConfig.normalize_attn_implementation(data) - assert result.get(expected_flag) is True, ( - f"{impl}: expected {expected_flag}=True" - ) - - @pytest.mark.parametrize("impl", ["xformers", "sage", "s2"]) - def test_attn_impl_does_not_set_flash_for_non_flash(self, impl): - """xformers, sage, s2 should NOT set flash_attention=True anymore.""" - result = self._normalize({"attn_implementation": impl}) - assert not result.get("flash_attention"), ( - f"{impl} should not set flash_attention" - ) - - def test_fp8_sets_no_legacy_flags(self): - result = self._normalize({"attn_implementation": "fp8"}) - for flag in [ - "flash_attention", - "sdp_attention", - "eager_attention", - "xformers_attention", - "sage_attention", - "flex_attention", - "s2_attention", - ]: - assert not result.get(flag), f"fp8 should not set {flag}" - - # --- Reverse mapping: legacy flags -> attn_implementation --- - - @pytest.mark.parametrize( - "flag,expected_impl", - [ - ("flash_attention", "flash"), + ("flash_attention", "flash_attention_2"), ("sdp_attention", "sdpa"), ("xformers_attention", "xformers"), - ("flex_attention", "flex"), + ("flex_attention", "flex_attention"), ("sage_attention", "sage"), ("eager_attention", "eager"), ("s2_attention", "s2"), ], ) - def test_legacy_flag_sets_attn_impl(self, flag, expected_impl): + def test_legacy_flag_maps_to_canonical(self, flag, expected): result = self._normalize({flag: True}) - assert result["attn_implementation"] == expected_impl + assert result["attn_implementation"] == expected - # --- Priority: s2/sage should win over flash when both set --- - - def test_s2_plus_flash_maps_to_s2(self): - """Legacy configs often have both s2_attention and flash_attention.""" - result = self._normalize({"s2_attention": True, "flash_attention": True}) - assert result["attn_implementation"] == "s2" - - def test_sage_plus_flash_maps_to_sage(self): - """sage_attention should take priority over flash_attention.""" - result = self._normalize({"sage_attention": True, "flash_attention": True}) - assert result["attn_implementation"] == "sage" - - # --- Consistency: both set, matching --- - - def test_consistent_both_set_no_error(self): - result = self._normalize( - {"attn_implementation": "flash", "flash_attention": True} - ) - assert result["attn_implementation"] == "flash" - assert result["flash_attention"] is True - - def test_consistent_xformers_with_own_flag(self): - """xformers + 
xformers_attention should be OK.""" - result = self._normalize( - {"attn_implementation": "xformers", "xformers_attention": True} - ) - assert result["attn_implementation"] == "xformers" - - # --- Conflict detection --- - - def test_conflicting_impl_and_flag_raises(self): - with pytest.raises(ValueError, match="conflicts with"): - self._normalize({"attn_implementation": "flash", "sdp_attention": True}) - - def test_conflicting_xformers_impl_with_sdp_flag(self): - with pytest.raises(ValueError, match="conflicts with"): - self._normalize({"attn_implementation": "xformers", "sdp_attention": True}) - - def test_xformers_with_flash_flag_conflicts(self): - """After normalizer change, xformers no longer expects flash_attention.""" - with pytest.raises(ValueError, match="conflicts with"): - self._normalize( - { - "attn_implementation": "xformers", - "xformers_attention": True, - "flash_attention": True, - } - ) - - def test_s2_with_flash_flag_conflicts(self): - """After normalizer change, s2 no longer expects flash_attention.""" - with pytest.raises(ValueError, match="conflicts with"): - self._normalize( - { - "attn_implementation": "s2", - "s2_attention": True, - "flash_attention": True, - } - ) - - # --- Hub kernel strings pass through --- - - def test_hub_kernel_passthrough(self): - result = self._normalize( - {"attn_implementation": "kernels-community/flash-attn3"} - ) - assert result["attn_implementation"] == "kernels-community/flash-attn3" - # Should not set any legacy flags + def test_legacy_flags_are_stripped_after_mapping(self): + result = self._normalize({"flash_attention": True}) for flag in [ "flash_attention", "sdp_attention", - "eager_attention", "xformers_attention", + "flex_attention", + "sage_attention", + "eager_attention", + "s2_attention", ]: - assert not result.get(flag) + assert flag not in result - def test_custom_string_passthrough(self): - result = self._normalize({"attn_implementation": "my_custom_kernel"}) - assert result["attn_implementation"] == "my_custom_kernel" + def test_s2_plus_flash_priority_is_s2(self): + result = self._normalize({"s2_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "s2" - # --- No attention set --- + def test_sage_plus_flash_priority_is_sage(self): + result = self._normalize({"sage_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "sage" + + +class TestNormalizerConflicts: + """Canonical attn_implementation + legacy flag raises.""" + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + def test_canonical_plus_legacy_flag_raises(self): + with pytest.raises(ValueError, match="cannot be combined with legacy"): + self._normalize( + {"attn_implementation": "flash_attention_2", "flash_attention": True} + ) + + def test_canonical_plus_unrelated_legacy_flag_raises(self): + with pytest.raises(ValueError, match="cannot be combined with legacy"): + self._normalize( + {"attn_implementation": "xformers", "flash_attention": True} + ) + + +class TestNormalizerPassthrough: + """Canonical values and hub-kernel paths pass through.""" + + def test_canonical_no_legacy_is_noop(self): + data = {"attn_implementation": "flash_attention_2"} + result = AxolotlInputConfig.normalize_attn_implementation(data) + assert result["attn_implementation"] == "flash_attention_2" + + def test_hub_kernel_passes_through(self): + data = {"attn_implementation": "kernels-community/flash-attn3"} + result = AxolotlInputConfig.normalize_attn_implementation(data) + 
assert result["attn_implementation"] == "kernels-community/flash-attn3" def test_no_attention_set_is_noop(self): - result = self._normalize({"some_other_config": True}) + result = AxolotlInputConfig.normalize_attn_implementation( + {"some_other_config": True} + ) assert result.get("attn_implementation") is None - # --- Gemma4 hybrid --- - def test_gemma4_hybrid_sets_flash(self): - """gemma4_hybrid_attn_impl should default attn_implementation to flash.""" - result = self._normalize({"gemma4_hybrid_attn_impl": True}) - assert result["attn_implementation"] == "flash" - assert result["flash_attention"] is True +class TestGemma4Hybrid: + """gemma4_hybrid_attn_impl defaults to flash_attention_2.""" - def test_gemma4_hybrid_does_not_override_explicit(self): - """If attn_implementation is already set, gemma4 should not override it.""" - result = self._normalize( + def test_gemma4_hybrid_defaults_to_fa2(self): + result = AxolotlInputConfig.normalize_attn_implementation( + {"gemma4_hybrid_attn_impl": True} + ) + assert result["attn_implementation"] == "flash_attention_2" + + def test_gemma4_hybrid_respects_explicit(self): + result = AxolotlInputConfig.normalize_attn_implementation( {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"} ) assert result["attn_implementation"] == "sdpa" -class TestAttnCapabilityProperties: - """Test the capability properties on the normalizer data. +class TestFieldValidator: + """attn_implementation field_validator rejects short-form aliases.""" - Since these are @property on AxolotlInputConfig (a Pydantic model), - we test the underlying logic directly using the constant sets. - """ + def test_canonical_accepted(self): + for impl in CANONICAL_ATTN_IMPLS: + assert AxolotlInputConfig.validate_attn_implementation(impl) == impl - # --- attn_supports_packing --- + def test_hub_kernel_accepted(self): + for impl in ( + "kernels-community/flash-attn3", + "kernels-community/sage-attention", + "someorg/custom-kernel", + ): + assert AxolotlInputConfig.validate_attn_implementation(impl) == impl - @pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"]) - def test_supports_packing_true(self, impl): - assert impl not in _NON_PACKING_ATTN_IMPLS + def test_none_accepted(self): + assert AxolotlInputConfig.validate_attn_implementation(None) is None - @pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"]) - def test_supports_packing_false(self, impl): - assert impl in _NON_PACKING_ATTN_IMPLS + @pytest.mark.parametrize("alias", ["flash", "flex", "sdp"]) + def test_short_form_alias_rejected(self, alias): + with pytest.raises(ValueError, match="is not accepted"): + AxolotlInputConfig.validate_attn_implementation(alias) - def test_hub_kernel_supports_packing(self): - """Unknown hub kernels should default to packing-capable.""" - assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS + def test_unknown_without_slash_rejected(self): + with pytest.raises(ValueError, match="not a recognized backend"): + AxolotlInputConfig.validate_attn_implementation("not_a_real_backend") - # --- attn_uses_flash_lib --- - @pytest.mark.parametrize("impl", ["flash", "s2"]) - def test_uses_flash_lib_true(self, impl): - assert impl in FLASH_ATTN_LIB_IMPLS +class TestCapabilityTables: + """Capability tables are keyed by canonical names and cover the expected backends.""" @pytest.mark.parametrize( - "impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"] + "impl", + [ + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + ], ) 
- def test_uses_flash_lib_false(self, impl): - assert impl not in FLASH_ATTN_LIB_IMPLS + def test_supports_packing(self, impl): + assert impl in ATTN_IMPLS_SUPPORTING_PACKING - def test_hub_kernel_not_flash_lib(self): - """Hub kernels are HF-managed, not axolotl monkeypatch targets.""" - assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS + @pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"]) + def test_does_not_support_packing(self, impl): + assert impl not in ATTN_IMPLS_SUPPORTING_PACKING - # --- attn_needs_dtype_cast --- + @pytest.mark.parametrize("impl", ["flash_attention_2", "flash_attention_3", "s2"]) + def test_uses_flash_lib(self, impl): + assert impl in ATTN_IMPLS_USING_FLASH_LIB + + @pytest.mark.parametrize( + "impl", ["eager", "sdpa", "xformers", "flex_attention", "sage", "fp8"] + ) + def test_does_not_use_flash_lib(self, impl): + assert impl not in ATTN_IMPLS_USING_FLASH_LIB @pytest.mark.parametrize("impl", ["eager", "sdpa"]) def test_no_dtype_cast(self, impl): - assert impl in _NO_DTYPE_CAST_ATTN_IMPLS - - @pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"]) - def test_needs_dtype_cast(self, impl): - assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS - - -class TestAttnImplToHFMapping: - """Test that attn_implementation enum values map correctly to HF strings.""" - - # This dict mirrors _ATTN_IMPL_TO_HF in model.py - _ATTN_IMPL_TO_HF = { - "eager": "eager", - "flash": "flash_attention_2", - "sdpa": "sdpa", - "xformers": "xformers", - "flex": "flex_attention", - "sage": "sage", - "s2": "flash_attention_2", - "fp8": "sdpa", - } + assert impl in ATTN_IMPLS_WITHOUT_DTYPE_CAST @pytest.mark.parametrize( - "impl,expected_hf", + "impl", [ - ("eager", "eager"), - ("flash", "flash_attention_2"), - ("sdpa", "sdpa"), - ("xformers", "xformers"), - ("flex", "flex_attention"), - ("sage", "sage"), - ("s2", "flash_attention_2"), - ("fp8", "sdpa"), + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + "s2", + "fp8", ], ) - def test_known_impl_maps_correctly(self, impl, expected_hf): - assert self._ATTN_IMPL_TO_HF[impl] == expected_hf + def test_needs_dtype_cast(self, impl): + assert impl not in ATTN_IMPLS_WITHOUT_DTYPE_CAST - def test_hub_kernel_falls_through(self): - """Hub kernel strings should pass through .get() unchanged.""" - hub_str = "kernels-community/flash-attn3" - result = self._ATTN_IMPL_TO_HF.get(hub_str, hub_str) - assert result == hub_str + def test_known_hub_kernels_classified(self): + assert "kernels-community/flash-attn3" in ATTN_IMPLS_SUPPORTING_PACKING + assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB + assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING def _xformers_available(): @@ -279,7 +212,7 @@ def _xformers_available(): class TestAttentionRegistration: - """Test that attention backends register correctly in HF's registries.""" + """Axolotl-owned backends register under their canonical names in HF's registries.""" @pytest.mark.skipif(not _xformers_available(), reason="xformers not available") def test_register_xformers(self): @@ -292,7 +225,6 @@ class TestAttentionRegistration: assert "xformers" in ALL_ATTENTION_FUNCTIONS assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS - # xformers mask should be the same function as flash_attention_2's mask assert ( ALL_MASK_ATTENTION_FUNCTIONS["xformers"] == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] @@ -315,7 +247,6 @@ class TestAttentionRegistration: @pytest.mark.skipif(not 
_xformers_available(), reason="xformers not available") def test_xformers_does_not_overwrite_fa2(self): - """Registering xformers should not modify the flash_attention_2 slot.""" from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] @@ -327,7 +258,6 @@ class TestAttentionRegistration: assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2 def test_sage_does_not_overwrite_fa2(self): - """Registering sage should not modify the flash_attention_2 slot.""" from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]