attn-implementation-refactor branch review
Review target: attn-implementation-refactor (5 commits ahead of main, merge base 69904781).
Scope: 16 files, +682 / −96.
1. What the branch is trying to do
Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.
Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (s2 + flash → s2), and then computes the three capability flags from frozen sets in `enums.py`.
Adjacent changes: xformers and sage now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New fp8 backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
2. Target design
`cfg.attn_implementation` is the single source of truth on the validated config.
- Its type is `str | None`. Accepted values are canonical names only — no short-form aliases:
  - HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
  - Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
  - Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. A known-kernel allowlist in `enums.py` classifies the common ones into the capability tables. Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the only input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in the JSON schema. The mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are `@computed_field` on the model, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get conservative capability defaults (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
- Downstream consumers read only `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.
This is strictly what the branch is already trying to do — the gaps below are places it hasn't landed that goal yet.
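To make the end state concrete, a minimal usage sketch (the import path and `base_model` value are stand-ins; constructor kwargs beyond the attention fields are elided):

```python
import warnings

from axolotl.utils.schemas.config import AxolotlInputConfig  # import path may differ

# Canonical string: stored verbatim, passed to HF verbatim, drives the computed flags.
cfg = AxolotlInputConfig(base_model="ns/model", attn_implementation="flash_attention_2")
assert cfg.attn_supports_packing and cfg.attn_uses_flash_lib

# Legacy boolean: accepted as an input alias, warns, never readable afterwards.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = AxolotlInputConfig(base_model="ns/model", flash_attention=True)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert cfg.attn_implementation == "flash_attention_2"
assert cfg.flash_attention is None  # wiped by the normalizer
```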
3. Gaps and holes
A. Legacy flags are still parallel state, not input-only
- The normalizer sets the legacy flags on `data` (the `impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
- Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
- Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash` + `flash_attention: true`) emits no deprecation warning — only the legacy-only path warns.
- Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.
B. Validators that still only check the legacy flag
- `check_ebft_activation_offloading` (validation.py:1607-1619) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
- `check_sample_packing_without_attention` (validation.py:188-203) early-returns when `attn_implementation` is set but never validates that the chosen backend actually supports packing. `attn_implementation: eager` + `sample_packing: true` silently passes; the old legacy-flag check warned.
C. Non-enum strings fall through the capability tables
- HF-native `"flash_attention_2"` is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, the LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
- Hub kernel strings (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently entering multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
- Conflict trap for hub-kernel + legacy flag (config.py:1414-1419): `attn_implementation: kernels-community/flash-attn3` + `flash_attention: true` always raises, because `impl_to_flag.get(custom)` is `None` and the loop treats `flag != None` as a conflict. A common combo in existing YAMLs breaks hard on upgrade.
D. Silent behaviour change for xformers
- Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for xformers + sample_packing. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.
E. Capability flags are writable Pydantic fields, not computed
- `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.
F. Validator ordering (not covered by tests)
- `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during the `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true` + `flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way; a probe test is sketched below.
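One way to settle this is a probe test that builds the real model with the legacy combo; a sketch, reusing the stand-in import and `base_model` from the section-2 sketch:

```python
def test_legacy_s2_plus_flash_builds():
    # Expected per the branch's intent: the normalizer wins and maps to "s2".
    # If the inherited mixin check runs first, this raises a ValidationError
    # ("multiple attention implementations set") instead -- which is the bug.
    cfg = AxolotlInputConfig(
        base_model="ns/model", s2_attention=True, flash_attention=True
    )
    assert cfg.attn_implementation == "s2"
```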
G. Remaining legacy reads in src/
- `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
- `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
- Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.
H. fp8 operational risk
- The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability / compute-capability guards exist; `fp8` should match.
I. Test coverage gaps
`test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does not:
- Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
- Verify capability flags can't be overridden from YAML (item 11).
- Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
- Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
- Cover the hub-kernel + legacy flag combo (item 9).
- Cover `flash_attention_2` canonicalization (item 7).
4. Fix plan
Four phases, each commit-sized. Phases 1–2 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.
Phase 1 — Lock down the data model
- Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) or a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name (see the sketch below).
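A minimal sketch of that validation, assuming the allowlist and hint map land in `enums.py` (constant and function names are illustrative; in the model this would hang off a `@field_validator`):

```python
# Illustrative constants; the real sets would live in enums.py.
CANONICAL_ATTN_IMPLS = frozenset({
    "eager", "sdpa", "flash_attention_2", "flash_attention_3",
    "flex_attention", "xformers", "sage", "s2", "fp8",
})
SHORT_FORM_HINTS = {  # rejected alias -> canonical name for the error message
    "flash": "flash_attention_2", "fa2": "flash_attention_2",
    "fa3": "flash_attention_3", "sdp": "sdpa", "flex": "flex_attention",
}

def validate_attn_implementation(value: str | None) -> str | None:
    if value is None or value in CANONICAL_ATTN_IMPLS:
        return value
    if "/" in value:  # hub-kernel path, e.g. kernels-community/flash-attn3
        return value
    if value in SHORT_FORM_HINTS:
        raise ValueError(
            f"attn_implementation '{value}' is not accepted; "
            f"use the canonical name '{SHORT_FORM_HINTS[value]}'"
        )
    raise ValueError(f"unknown attn_implementation '{value}'")
```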
- Rewrite `normalize_attn_implementation` so its only job is mapping legacy booleans → canonical `attn_implementation` (for BC). The mapping is fixed: `flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`. Priority for legacy combos stays as in the current branch (s2 > sage > xformers > flex > flash > sdp > eager). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on the validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence). A sketch follows.
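A sketch of the slimmed-down normalizer under those rules (the function body is illustrative; in the real model it runs under `@model_validator(mode="before")`, and dict-literal order encodes the combo priority):

```python
import warnings

# Priority order: s2 > sage > xformers > flex > flash > sdp > eager.
LEGACY_FLAG_TO_IMPL = {
    "s2_attention": "s2",
    "sage_attention": "sage",
    "xformers_attention": "xformers",
    "flex_attention": "flex_attention",
    "flash_attention": "flash_attention_2",
    "sdp_attention": "sdpa",
    "eager_attention": "eager",
}

def normalize_attn_implementation(data: dict) -> dict:
    set_flags = [flag for flag in LEGACY_FLAG_TO_IMPL if data.get(flag)]
    for flag in set_flags:
        # one warning per flag set; dedup is left to the warnings filter
        warnings.warn(f"{flag} is deprecated; set attn_implementation instead",
                      DeprecationWarning)
    if set_flags and data.get("attn_implementation"):
        raise ValueError("set either attn_implementation or a legacy attention flag, not both")
    if set_flags:
        data["attn_implementation"] = LEGACY_FLAG_TO_IMPL[set_flags[0]]  # priority winner
    for flag in LEGACY_FLAG_TO_IMPL:
        data.pop(flag, None)  # input-only: never lands on the validated cfg
    return data
```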
- Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method. Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like s2 + flash before the normalizer can map them. One validator, one source of conflict detection.
- Fix the gemma4-hybrid path: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
- Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
  - Packing-capable: {`flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `kernels-community/flash-attn3`, `kernels-community/sage-attention`}.
  - Flash-lib (axolotl's monkeypatch targets): {`flash_attention_2`, `flash_attention_3`, `s2`, `kernels-community/flash-attn3`}.
  - No-dtype-cast: {`eager`, `sdpa`}.
  - Unknown strings: conservative defaults (packing=False, flash_lib=False, dtype_cast=True). See the sketch after this list.
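A sketch of the computed flags, assuming the three frozen sets above live in `enums.py` (class and constant names are illustrative):

```python
from pydantic import BaseModel, computed_field

PACKING_CAPABLE_IMPLS = frozenset({
    "flash_attention_2", "flash_attention_3", "flex_attention", "xformers",
    "sage", "kernels-community/flash-attn3", "kernels-community/sage-attention",
})
FLASH_ATTN_LIB_IMPLS = frozenset({
    "flash_attention_2", "flash_attention_3", "s2", "kernels-community/flash-attn3",
})
NO_DTYPE_CAST_IMPLS = frozenset({"eager", "sdpa"})

class AttnCapabilityMixin(BaseModel):
    attn_implementation: str | None = None

    @computed_field  # read-only: YAML cannot set or override it
    @property
    def attn_supports_packing(self) -> bool:
        return self.attn_implementation in PACKING_CAPABLE_IMPLS

    @computed_field
    @property
    def attn_uses_flash_lib(self) -> bool:
        return self.attn_implementation in FLASH_ATTN_LIB_IMPLS

    @computed_field
    @property
    def attn_needs_dtype_cast(self) -> bool:
        # unknown strings default conservatively: cast unless known not to need it
        return self.attn_implementation not in NO_DTYPE_CAST_IMPLS
```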
- Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
- `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.
Phase 2 — Fix the validators
- `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
- `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
- Sweep items (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` dual-check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
- fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`. A preflight sketch follows this list.
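A sketch of that preflight (function name, message text, and exact gating are illustrative; the requirements are the ones quoted from the fp8 docstring):

```python
import warnings

import torch
from packaging import version

def check_fp8_attention_requirements(cfg) -> None:
    # PyTorch floor; strip any local version suffix like "+cu124" before parsing
    if version.parse(torch.__version__.split("+")[0]) < version.parse("2.11"):
        raise ValueError("attn_implementation: fp8 requires PyTorch >= 2.11")
    try:
        import torchao
    except ImportError as err:
        raise ValueError("attn_implementation: fp8 requires torchao >= 0.17.0") from err
    if version.parse(torchao.__version__) < version.parse("0.17.0"):
        raise ValueError("attn_implementation: fp8 requires torchao >= 0.17.0")
    major, _minor = torch.cuda.get_device_capability()
    if major < 9:  # SM90+ (Hopper or newer)
        raise ValueError("attn_implementation: fp8 requires compute capability >= sm_90")
    if cfg.use_cache is not False:
        warnings.warn("fp8 attention needs config.use_cache = False; set it explicitly")
```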
Phase 3 — Migrate remaining consumers
- `lm_eval/cli.py:120` → `flash_attention=cfg.attn_uses_flash_lib`.
- `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
- `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
- `tests/e2e/multigpu/test_llama.py:524-526` → `cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
- Xformers decision (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm / stablelm_epoch+packing / mistral3 / llava. Two choices:
  - Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keep the patches live).
  - Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
  Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
- Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting. A sketch of such a check follows.
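A sketch of that check as a plain pytest (the allowlisted normalizer path is an assumption; a grep-based pre-commit hook would be equivalent):

```python
import pathlib
import re

# Matches attribute reads like cfg.flash_attention; data.get(...) idioms are
# handled by the phase-2 sweep, not this guard.
LEGACY_READ = re.compile(r"cfg\.(flash|sdp|xformers|flex|sage|s2|eager)_attention\b")
ALLOWED = {"src/axolotl/utils/schemas/config.py"}  # normalizer may still touch the flags

def test_no_legacy_attn_reads():
    repo = pathlib.Path(__file__).parents[1]
    offenders = [
        f"{path.relative_to(repo)}:{lineno}"
        for path in (repo / "src").rglob("*.py")
        if str(path.relative_to(repo)) not in ALLOWED
        for lineno, line in enumerate(path.read_text().splitlines(), start=1)
        if LEGACY_READ.search(line)
    ]
    assert not offenders, f"legacy attention flag reads found: {offenders}"
```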
Phase 4 — Tests + docs
- Rewrite `test_attn_implementation.py` to build a full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
- Add one test per gap closed above: `attn_implementation: eager` + `sample_packing`; `attn_implementation: flex_attention` + `activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from the validated `cfg`; fp8 preflight failures. A few of these are sketched after this list.
- Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
- `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (the normalizer maps them with a deprecation), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
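Sketches for a few of those cases as full-model tests (import path and `base_model` are stand-ins, as before; kwargs beyond the attention fields are elided):

```python
import pytest
from pydantic import ValidationError

from axolotl.utils.schemas.config import AxolotlInputConfig  # path may differ

def test_legacy_flag_maps_warns_and_is_wiped():
    with pytest.warns(DeprecationWarning):
        cfg = AxolotlInputConfig(base_model="ns/model", flash_attention=True)
    assert cfg.attn_implementation == "flash_attention_2"
    assert cfg.flash_attention is None  # input-only alias

def test_short_form_rejected_with_canonical_hint():
    with pytest.raises(ValidationError, match="flash_attention_2"):
        AxolotlInputConfig(base_model="ns/model", attn_implementation="flash")

def test_capability_flags_not_settable_from_yaml():
    with pytest.raises(ValidationError):
        AxolotlInputConfig(base_model="ns/model", attn_uses_flash_lib=True)
```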
5. Ordering & risk
- Phase 1 is the keystone: it's the largest diff, but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 13 is the explicit decision point).
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
- Phase 3 step 13 is the only judgment call — flag for review before choosing.
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.