make attn_implementation the single source of truth

This commit is contained in:
Wing Lian
2026-04-23 21:17:10 +00:00
parent 35d43fe141
commit 2579c496d5
10 changed files with 491 additions and 387 deletions

142
ATTN_REFACTOR_REVIEW.md Normal file
View File

@@ -0,0 +1,142 @@
# `attn-implementation-refactor` branch review
Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`).
Scope: 16 files, +682 / 96.
## 1. What the branch is trying to do
Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.
Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`.
Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
## 2. Target design
**`cfg.attn_implementation` is the single source of truth on the validated config.**
- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases:
- HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
- Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
- Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables.
Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.
This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet.
## 3. Gaps and holes
### A. Legacy flags are still parallel state, not input-only
1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only legacy-only path warns.
4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.
### B. Validators that still only check the legacy flag
5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned.
### C. Non-enum strings fall through the capability tables
7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as conflict. Common combo in existing YAMLs breaks hard on upgrade.
### D. Silent behaviour change for xformers
10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.
### E. Capability flags are writable Pydantic fields, not computed
11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.
### F. Validator ordering (not covered by tests)
12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way.
### G. Remaining legacy reads in `src/`
13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.
### H. fp8 operational risk
16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match.
### I. Test coverage gaps
17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**:
- Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
- Verify capability flags can't be overridden from YAML (item 11).
- Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
- Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
- Cover hub-kernel + legacy flag combo (item 9).
- Cover `flash_attention_2` canonicalization (item 7).
## 4. Fix plan
Four phases, each commit-sized. Phases 12 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.
### Phase 1 — Lock down the data model
1. Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name.
2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed:
- `flash_attention → flash_attention_2`
- `sdp_attention → sdpa`
- `flex_attention → flex_attention`
- `xformers_attention → xformers`
- `sage_attention → sage`
- `s2_attention → s2`
- `eager_attention → eager`
Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence).
**Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection.
**Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
- Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`.
- Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`.
- No-dtype-cast: `{eager, sdpa}`.
Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`).
4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.
### Phase 2 — Fix the validators
6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
8. Sweep items (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` dual-check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`.
### Phase 3 — Migrate remaining consumers
10. `lm_eval/cli.py:120``flash_attention=cfg.attn_uses_flash_lib`.
11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
13. `tests/e2e/multigpu/test_llama.py:524-526``cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
- Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keeps patches live).
- Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting.
### Phase 4 — Tests + docs
14. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
15. Add one test per gap closed above: `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
16. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
17. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (normalizer maps them with deprecation), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
## 5. Ordering & risk
- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 13 is the explicit decision point).
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
- Phase 3 step 13 is the only judgment call — flag for review before choosing.
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.

View File

@@ -502,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama # supported multipack models, or non-flash-attention llama
if ( if (
self.cfg.attn_implementation == "flex" self.cfg.attn_implementation == "flex_attention"
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or ( or (
self.cfg.model_config_type in ["llama"] self.cfg.model_config_type in ["llama"]
and self.cfg.attn_implementation != "flash" and self.cfg.attn_implementation != "flash_attention_2"
) )
): ):
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq

View File

@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
for lm_eval_args in build_lm_eval_command( for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks, cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16, bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=(cfg.attn_implementation == "flash"), flash_attention=cfg.attn_uses_flash_lib,
output_dir=cfg.output_dir, output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size, batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,

View File

@@ -628,33 +628,25 @@ class ModelLoader:
) )
def _set_attention_config(self): def _set_attention_config(self):
"""Sample packing uses custom FA2 patch""" # s2 and fp8 need a different HF backend at load time than their
# Map attn_implementation enum values to HF attn_implementation strings. # canonical name: s2 patches FA2 internals, so load under FA2; fp8
# xformers/sage are registered in ALL_ATTENTION_FUNCTIONS and # replaces F.scaled_dot_product_attention post-load, so load under sdpa.
# ALL_MASK_ATTENTION_FUNCTIONS under their own names with FA2 mask # Every other canonical name (and hub-kernel paths) is passed through
# behavior, so they no longer need to masquerade as flash_attention_2. # verbatim — xformers/sage/flash_attention_* are registered under their
# s2 still uses flash_attention_2 because it modifies FA2 internals. # own names in ALL_ATTENTION_FUNCTIONS before model load.
# Hub kernel strings (e.g. "kernels-community/flash-attn3") fall _LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
# through the .get() and are passed to HF unchanged.
_ATTN_IMPL_TO_HF = {
"eager": "eager",
"flash": "flash_attention_2",
"sdpa": "sdpa",
"xformers": "xformers",
"flex": "flex_attention",
"sage": "sage",
"s2": "flash_attention_2",
"fp8": "sdpa",
}
if self.cfg.gemma4_hybrid_attn_impl: if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers; # Load with flash_attention_2 for sliding-window layers; global
# global layers will be patched to sdpa post-load. # layers are swapped to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2" hf_impl = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.attn_implementation: elif self.cfg.attn_implementation:
hf_impl = _ATTN_IMPL_TO_HF.get( hf_impl = _LOAD_TIME_OVERRIDE.get(
self.cfg.attn_implementation, self.cfg.attn_implementation self.cfg.attn_implementation, self.cfg.attn_implementation
) )
else:
hf_impl = None
if hf_impl is not None:
self.model_kwargs["attn_implementation"] = hf_impl self.model_kwargs["attn_implementation"] = hf_impl
self.model_config._attn_implementation = hf_impl self.model_config._attn_implementation = hf_impl

View File

@@ -333,7 +333,7 @@ class PatchManager:
def _apply_flex_attention_patches(self): def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention.""" """Apply patches for flexible attention."""
if self.cfg.attn_implementation == "flex": if self.cfg.attn_implementation == "flex_attention":
from axolotl.monkeypatch.attention.flex_attn import ( from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper, patch_flex_wrapper,
) )

View File

@@ -207,7 +207,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
# Mistral's official FA implementation requires left padding # Mistral's official FA implementation requires left padding
if ( if (
cfg.is_mistral_derived_model cfg.is_mistral_derived_model
and cfg.attn_implementation == "flash" and cfg.attn_implementation == "flash_attention_2"
and not cfg.sample_packing and not cfg.sample_packing
): ):
tokenizer.padding_side = "left" tokenizer.padding_side = "left"

View File

@@ -10,7 +10,9 @@ from pydantic import (
BaseModel, BaseModel,
Field, Field,
StringConstraints, StringConstraints,
computed_field,
field_serializer, field_serializer,
field_validator,
model_validator, model_validator,
) )
@@ -28,10 +30,12 @@ from axolotl.utils.schemas.datasets import (
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import ( from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS, ATTN_IMPLS_SUPPORTING_PACKING,
_NON_PACKING_ATTN_IMPLS, ATTN_IMPLS_USING_FLASH_LIB,
FLASH_ATTN_LIB_IMPLS, ATTN_IMPLS_WITHOUT_DTYPE_CAST,
AttnImplementation, CANONICAL_ATTN_IMPLS,
LEGACY_ATTN_FLAG_TO_IMPL,
SHORT_FORM_ALIAS_TO_CANONICAL,
ChatTemplate, ChatTemplate,
RingAttnFunc, RingAttnFunc,
RLType, RLType,
@@ -739,28 +743,35 @@ class AxolotlInputConfig(
xformers_attention: bool | None = Field( xformers_attention: bool | None = Field(
default=None, default=None,
deprecated="Use `attn_implementation: xformers` instead.",
json_schema_extra={ json_schema_extra={
"description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers" "description": "[DEPRECATED] Use `attn_implementation: xformers`. https://github.com/facebookresearch/xformers"
}, },
) )
sdp_attention: bool | None = Field( sdp_attention: bool | None = Field(
default=None, default=None,
deprecated="Use `attn_implementation: sdpa` instead.",
json_schema_extra={ json_schema_extra={
"description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" "description": "[DEPRECATED] Use `attn_implementation: sdpa`."
}, },
) )
s2_attention: bool | None = Field( s2_attention: bool | None = Field(
default=None, default=None,
deprecated="Use `attn_implementation: s2` instead.",
json_schema_extra={ json_schema_extra={
"description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" "description": "[DEPRECATED] Use `attn_implementation: s2`. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf"
}, },
) )
flex_attention: bool | None = None flex_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: flex_attention` instead.",
)
flex_attn_compile_kwargs: dict[str, Any] | None = None flex_attn_compile_kwargs: dict[str, Any] | None = None
flash_attention: bool | None = Field( flash_attention: bool | None = Field(
default=None, default=None,
deprecated="Use `attn_implementation: flash_attention_2` instead.",
json_schema_extra={ json_schema_extra={
"description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention" "description": "[DEPRECATED] Use `attn_implementation: flash_attention_2`. https://github.com/Dao-AILab/flash-attention"
}, },
) )
flash_attn_cross_entropy: bool | None = Field( flash_attn_cross_entropy: bool | None = Field(
@@ -787,17 +798,26 @@ class AxolotlInputConfig(
) )
sage_attention: bool | None = Field( sage_attention: bool | None = Field(
default=None, default=None,
deprecated="Use `attn_implementation: sage` instead.",
json_schema_extra={ json_schema_extra={
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention" "description": "[DEPRECATED] Use `attn_implementation: sage`. https://github.com/thu-ml/SageAttention"
}, },
) )
eager_attention: bool | None = None eager_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: eager` instead.",
)
attn_implementation: AttnImplementation | str | None = Field( attn_implementation: str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Attention backend: eager, flash, sdpa, xformers, flex, sage, s2, fp8, or a custom string for kernels." "description": (
"Attention backend. Canonical values: eager, sdpa, flash_attention_2, "
"flash_attention_3, flex_attention, xformers, sage, s2, fp8. Hub-kernel "
"paths (e.g. kernels-community/flash-attn3) are also accepted and passed "
"through to transformers."
)
}, },
) )
@@ -1335,29 +1355,24 @@ class AxolotlInputConfig(
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None return None
# --- Attention capability flags (computed by normalize_attn_implementation) --- # --- Attention capability flags (derived from attn_implementation) ---
attn_supports_packing: bool | None = Field( @computed_field # type: ignore[misc]
default=None, @property
json_schema_extra={ def attn_supports_packing(self) -> bool:
"description": "Whether the attention backend supports varlen sample packing. " return self.attn_implementation in ATTN_IMPLS_SUPPORTING_PACKING
"Computed automatically from attn_implementation."
}, @computed_field # type: ignore[misc]
) @property
attn_uses_flash_lib: bool | None = Field( def attn_uses_flash_lib(self) -> bool:
default=None, return self.attn_implementation in ATTN_IMPLS_USING_FLASH_LIB
json_schema_extra={
"description": "Whether the attention backend requires axolotl's flash_attn " @computed_field # type: ignore[misc]
"monkeypatches. Computed automatically from attn_implementation." @property
}, def attn_needs_dtype_cast(self) -> bool:
) if self.attn_implementation is None:
attn_needs_dtype_cast: bool | None = Field( return False
default=None, return self.attn_implementation not in ATTN_IMPLS_WITHOUT_DTYPE_CAST
json_schema_extra={
"description": "Whether the attention backend needs embedding dtype cast to "
"fp16/bf16. Computed automatically from attn_implementation."
},
)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -1382,90 +1397,83 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def normalize_attn_implementation(cls, data): def normalize_attn_implementation(cls, data):
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags.""" """Map legacy boolean attention flags to the canonical `attn_implementation`.
`attn_implementation` is the single source of truth on the validated
config. Legacy booleans (`flash_attention: true`, …) are input-only
aliases; this validator warns, maps them to their canonical value, and
strips them from `data` so they cannot be read downstream.
Raises if a canonical `attn_implementation` is set alongside any legacy
boolean — users must pick one.
"""
if not isinstance(data, dict):
return data
attn_impl = data.get("attn_implementation") attn_impl = data.get("attn_implementation")
set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)]
# If gemma4_hybrid_attn_impl is set but no attn_implementation, default # gemma4_hybrid defaults to flash_attention_2 when user didn't pick a
# to flash (the sliding-window layers use FA2, and packing should be enabled). # backend. The sliding-window layers run under FA2; post-load patching
if data.get("gemma4_hybrid_attn_impl") and not attn_impl: # swaps global layers to sdpa (see `_apply_gemma_hybrid_attention`).
data["attn_implementation"] = "flash" if data.get("gemma4_hybrid_attn_impl") and not attn_impl and not set_flags:
attn_impl = "flash" data["attn_implementation"] = "flash_attention_2"
attn_impl = "flash_attention_2"
# Mapping: attn_implementation value -> primary legacy flag to set
impl_to_flag = {
"eager": "eager_attention",
"flash": "flash_attention",
"sdpa": "sdp_attention",
"xformers": "xformers_attention",
"flex": "flex_attention",
"sage": "sage_attention",
"s2": "s2_attention",
"fp8": None, # new, no legacy flag
}
# Reverse mapping: legacy flag -> attn_implementation value
flag_to_impl = {
"eager_attention": "eager",
"flash_attention": "flash",
"sdp_attention": "sdpa",
"xformers_attention": "xformers",
"flex_attention": "flex",
"sage_attention": "sage",
"s2_attention": "s2",
}
# Find which legacy flags are set
set_flags = [f for f, impl in flag_to_impl.items() if data.get(f)]
if attn_impl and set_flags: if attn_impl and set_flags:
# Both set — check consistency raise ValueError(
expected_flag = impl_to_flag.get(attn_impl) f"attn_implementation={attn_impl!r} cannot be combined with legacy "
for flag in set_flags: f"attention flags ({', '.join(sorted(set_flags))}). The legacy "
if flag != expected_flag: f"flags are deprecated — set only `attn_implementation`."
raise ValueError( )
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
f"Use only attn_implementation or the legacy flag, not both." if not attn_impl and set_flags:
) # Priority: specific backends beat generic flash/sdp/eager fallbacks.
elif attn_impl and not set_flags: for flag in LEGACY_ATTN_FLAG_TO_IMPL:
# attn_implementation set, no legacy flags — set primary for backwards compat
flag = impl_to_flag.get(attn_impl)
if flag:
data[flag] = True
elif not attn_impl and set_flags:
# Legacy flags set, no attn_implementation — map to enum, warn
# Priority: specific backends first, then generic flash/sdp/eager
priority = [
"xformers_attention",
"s2_attention",
"sage_attention",
"flex_attention",
"flash_attention",
"sdp_attention",
"eager_attention",
]
for flag in priority:
if flag in set_flags: if flag in set_flags:
data["attn_implementation"] = flag_to_impl[flag] canonical = LEGACY_ATTN_FLAG_TO_IMPL[flag]
data["attn_implementation"] = canonical
LOG.warning( LOG.warning(
"`%s: true` is deprecated. Use `attn_implementation: %s` instead.", "`%s: true` is deprecated and will be removed in a future "
"release. Use `attn_implementation: %s` instead.",
flag, flag,
flag_to_impl[flag], canonical,
) )
break break
# Compute capability flags from the final attn_implementation value # Strip legacy flags from validated data — canonical field is authoritative.
impl = data.get("attn_implementation") for flag in LEGACY_ATTN_FLAG_TO_IMPL:
if impl: data.pop(flag, None)
data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS
data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS
data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS
else:
data["attn_supports_packing"] = False
data["attn_uses_flash_lib"] = False
data["attn_needs_dtype_cast"] = False
return data return data
@field_validator("attn_implementation", mode="before")
@classmethod
def validate_attn_implementation(cls, value):
"""Accept canonical names and hub-kernel paths; reject short-form aliases."""
if value is None:
return None
if not isinstance(value, str):
raise TypeError(
f"attn_implementation must be a string, got {type(value).__name__}"
)
if value in CANONICAL_ATTN_IMPLS:
return value
if "/" in value:
# Hub-kernel path, e.g. "kernels-community/flash-attn3". Pass through.
return value
if value in SHORT_FORM_ALIAS_TO_CANONICAL:
canonical = SHORT_FORM_ALIAS_TO_CANONICAL[value]
raise ValueError(
f"attn_implementation={value!r} is not accepted. "
f"Use the canonical name {canonical!r} instead."
)
raise ValueError(
f"attn_implementation={value!r} is not a recognized backend. "
f"Expected one of: {sorted(CANONICAL_ATTN_IMPLS)}, or a hub-kernel "
f"path containing '/'."
)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sageattn_wo_sample_packing(cls, data): def check_sageattn_wo_sample_packing(cls, data):
@@ -1763,7 +1771,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_flex_torch_version(cls, data): def check_flex_torch_version(cls, data):
if data.get("flex_attention") or data.get("attn_implementation") == "flex": if (
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
):
env_capabilities = data.get("env_capabilities", {}) env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version") torch_version = env_capabilities.get("torch_version")

View File

@@ -97,30 +97,75 @@ class CustomSupportedOptimizers(str, Enum):
flash_lion = "flash_lion" flash_lion = "flash_lion"
class AttnImplementation(str, Enum): # Canonical values accepted for `attn_implementation`. These are passed to HF
"""Attention backend implementations""" # verbatim via `model.config._attn_implementation`. HF-native backends use HF's
# own names (`flash_attention_2`, `flex_attention`, ...); axolotl-owned backends
# (`xformers`, `sage`, `s2`, `fp8`) register into `ALL_ATTENTION_FUNCTIONS` under
# these exact names. Hub-kernel paths (e.g. `kernels-community/flash-attn3`) are
# not in this set — they pass through the validator via the "/" check.
CANONICAL_ATTN_IMPLS = frozenset(
{
"eager",
"sdpa",
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
"s2",
"fp8",
}
)
eager = "eager" # pylint: disable=invalid-name # Legacy boolean attention flags → canonical `attn_implementation`. Kept for
flash = "flash" # pylint: disable=invalid-name # backwards compatibility; the normalizer warns and strips these from the
sdpa = "sdpa" # pylint: disable=invalid-name # validated config. Priority order (first match wins) matches the old priority:
xformers = "xformers" # pylint: disable=invalid-name # specific backends beat the generic flash/sdp/eager fallbacks.
flex = "flex" # pylint: disable=invalid-name LEGACY_ATTN_FLAG_TO_IMPL = {
sage = "sage" # pylint: disable=invalid-name "xformers_attention": "xformers",
s2 = "s2" # pylint: disable=invalid-name "s2_attention": "s2",
fp8 = "fp8" # pylint: disable=invalid-name "sage_attention": "sage",
"flex_attention": "flex_attention",
"flash_attention": "flash_attention_2",
"sdp_attention": "sdpa",
"eager_attention": "eager",
}
# Short-form aliases that were accepted by the in-progress branch but are
# rejected going forward. Mapped to canonical names only to produce a helpful
# error message pointing users at the right value.
SHORT_FORM_ALIAS_TO_CANONICAL = {
"flash": "flash_attention_2",
"flex": "flex_attention",
"sdp": "sdpa",
}
# Backends that require the flash_attn library (Dao-AILab/flash-attention) # Backends that support varlen sample packing via `position_ids`.
# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.) ATTN_IMPLS_SUPPORTING_PACKING = frozenset(
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"}) {
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
"kernels-community/flash-attn3",
"kernels-community/sage-attention",
}
)
# Known backends that do NOT support varlen sample packing via position_ids. # Backends that require the flash_attn library (Dao-AILab/flash-attention) for
# Used as an exclusion list: unknown strings (e.g., HF hub kernels like # axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, ring-FA, ...).
# "kernels-community/flash-attn3") default to packing-capable. ATTN_IMPLS_USING_FLASH_LIB = frozenset(
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"}) {
"flash_attention_2",
"flash_attention_3",
"s2",
"kernels-community/flash-attn3",
}
)
# Known backends that do NOT need embedding dtype cast. # Backends for which embeddings stay in fp32. Everything else needs fp16/bf16.
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"}) ATTN_IMPLS_WITHOUT_DTYPE_CAST = frozenset({"eager", "sdpa"})
class RingAttnFunc(str, Enum): class RingAttnFunc(str, Enum):

View File

@@ -13,7 +13,7 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ( from axolotl.utils.schemas.enums import (
_NON_PACKING_ATTN_IMPLS, ATTN_IMPLS_SUPPORTING_PACKING,
ChatTemplate, ChatTemplate,
RingAttnFunc, RingAttnFunc,
RLType, RLType,
@@ -184,26 +184,8 @@ class DatasetValidationMixin:
class AttentionValidationMixin: class AttentionValidationMixin:
"""Validation methods related to attention mechanisms.""" """Validation methods related to attention mechanisms."""
@model_validator(mode="before") # `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
@classmethod # is now the single entry point for attention-input mapping and conflict detection.
def check_attention_fields(cls, data):
# If attn_implementation is set, the enum handles mutual exclusivity.
# This validator catches legacy configs with multiple boolean flags.
if data.get("attn_implementation"):
return data
fields = (
"xformers_attention",
"sdp_attention",
# "s2_attention", # requires both FA and this to be enabled
"flash_attention",
"flex_attention",
"sage_attention",
)
non_empty_count = sum(1 for field in fields if data.get(field))
if non_empty_count > 1:
raise ValueError(f"Only one of {', '.join(fields)} must be set")
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -238,7 +220,8 @@ class AttentionValidationMixin:
@classmethod @classmethod
def check_scaling_softmax_requires_flex(cls, data): def check_scaling_softmax_requires_flex(cls, data):
if data.get("scaling_softmax") and not ( if data.get("scaling_softmax") and not (
data.get("flex_attention") or data.get("attn_implementation") == "flex" data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
): ):
raise ValueError( raise ValueError(
"scaling_softmax requires flex attention.\n" "scaling_softmax requires flex attention.\n"
@@ -956,7 +939,7 @@ class OptimizationValidationMixin:
if data.get("batch_flattening"): if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto" batch_flattening_auto = data.get("batch_flattening") == "auto"
has_varlen_attn = ( has_varlen_attn = (
data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING
if data.get("attn_implementation") if data.get("attn_implementation")
else data.get("flash_attention") else data.get("flash_attention")
) )
@@ -1683,7 +1666,8 @@ class EBFTValidationMixin:
data.get("rl") == "ebft" data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided" and data.get("ebft", {}).get("mode") == "strided"
and ( and (
data.get("flex_attention") or data.get("attn_implementation") == "flex" data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
) )
and data.get("gradient_checkpointing") and data.get("gradient_checkpointing")
): ):

View File

@@ -1,272 +1,205 @@
""" """Tests for attn_implementation normalization and capability computation.
Tests for attn_implementation normalization, registry registration,
capability properties, and backwards compatibility with legacy boolean Covers the Phase 1 contract:
attention flags. - `attn_implementation` accepts canonical names only; short-form aliases are rejected.
- Legacy boolean flags are mapped to the canonical value, warned on, and stripped.
- Canonical `attn_implementation` + legacy flag raises.
- Capability flags are computed from `attn_implementation`.
""" """
import pytest import pytest
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.schemas.enums import ( from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS, ATTN_IMPLS_SUPPORTING_PACKING,
_NON_PACKING_ATTN_IMPLS, ATTN_IMPLS_USING_FLASH_LIB,
FLASH_ATTN_LIB_IMPLS, ATTN_IMPLS_WITHOUT_DTYPE_CAST,
CANONICAL_ATTN_IMPLS,
) )
class TestAttnImplementationNormalizer: class TestNormalizerLegacyMapping:
"""Test the normalize_attn_implementation validator.""" """Legacy boolean flags map to canonical attn_implementation."""
@staticmethod @staticmethod
def _normalize(data): def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data) return AxolotlInputConfig.normalize_attn_implementation(data)
# --- Forward mapping: attn_implementation -> legacy flags ---
@pytest.mark.parametrize( @pytest.mark.parametrize(
"impl,expected_flag", "flag,expected",
[ [
("eager", "eager_attention"), ("flash_attention", "flash_attention_2"),
("flash", "flash_attention"),
("sdpa", "sdp_attention"),
("flex", "flex_attention"),
("xformers", "xformers_attention"),
("sage", "sage_attention"),
("s2", "s2_attention"),
],
)
def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
data = {"attn_implementation": impl}
result = AxolotlInputConfig.normalize_attn_implementation(data)
assert result.get(expected_flag) is True, (
f"{impl}: expected {expected_flag}=True"
)
@pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
"""xformers, sage, s2 should NOT set flash_attention=True anymore."""
result = self._normalize({"attn_implementation": impl})
assert not result.get("flash_attention"), (
f"{impl} should not set flash_attention"
)
def test_fp8_sets_no_legacy_flags(self):
result = self._normalize({"attn_implementation": "fp8"})
for flag in [
"flash_attention",
"sdp_attention",
"eager_attention",
"xformers_attention",
"sage_attention",
"flex_attention",
"s2_attention",
]:
assert not result.get(flag), f"fp8 should not set {flag}"
# --- Reverse mapping: legacy flags -> attn_implementation ---
@pytest.mark.parametrize(
"flag,expected_impl",
[
("flash_attention", "flash"),
("sdp_attention", "sdpa"), ("sdp_attention", "sdpa"),
("xformers_attention", "xformers"), ("xformers_attention", "xformers"),
("flex_attention", "flex"), ("flex_attention", "flex_attention"),
("sage_attention", "sage"), ("sage_attention", "sage"),
("eager_attention", "eager"), ("eager_attention", "eager"),
("s2_attention", "s2"), ("s2_attention", "s2"),
], ],
) )
def test_legacy_flag_sets_attn_impl(self, flag, expected_impl): def test_legacy_flag_maps_to_canonical(self, flag, expected):
result = self._normalize({flag: True}) result = self._normalize({flag: True})
assert result["attn_implementation"] == expected_impl assert result["attn_implementation"] == expected
# --- Priority: s2/sage should win over flash when both set --- def test_legacy_flags_are_stripped_after_mapping(self):
result = self._normalize({"flash_attention": True})
def test_s2_plus_flash_maps_to_s2(self):
"""Legacy configs often have both s2_attention and flash_attention."""
result = self._normalize({"s2_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "s2"
def test_sage_plus_flash_maps_to_sage(self):
"""sage_attention should take priority over flash_attention."""
result = self._normalize({"sage_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "sage"
# --- Consistency: both set, matching ---
def test_consistent_both_set_no_error(self):
result = self._normalize(
{"attn_implementation": "flash", "flash_attention": True}
)
assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True
def test_consistent_xformers_with_own_flag(self):
"""xformers + xformers_attention should be OK."""
result = self._normalize(
{"attn_implementation": "xformers", "xformers_attention": True}
)
assert result["attn_implementation"] == "xformers"
# --- Conflict detection ---
def test_conflicting_impl_and_flag_raises(self):
with pytest.raises(ValueError, match="conflicts with"):
self._normalize({"attn_implementation": "flash", "sdp_attention": True})
def test_conflicting_xformers_impl_with_sdp_flag(self):
with pytest.raises(ValueError, match="conflicts with"):
self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
def test_xformers_with_flash_flag_conflicts(self):
"""After normalizer change, xformers no longer expects flash_attention."""
with pytest.raises(ValueError, match="conflicts with"):
self._normalize(
{
"attn_implementation": "xformers",
"xformers_attention": True,
"flash_attention": True,
}
)
def test_s2_with_flash_flag_conflicts(self):
"""After normalizer change, s2 no longer expects flash_attention."""
with pytest.raises(ValueError, match="conflicts with"):
self._normalize(
{
"attn_implementation": "s2",
"s2_attention": True,
"flash_attention": True,
}
)
# --- Hub kernel strings pass through ---
def test_hub_kernel_passthrough(self):
result = self._normalize(
{"attn_implementation": "kernels-community/flash-attn3"}
)
assert result["attn_implementation"] == "kernels-community/flash-attn3"
# Should not set any legacy flags
for flag in [ for flag in [
"flash_attention", "flash_attention",
"sdp_attention", "sdp_attention",
"eager_attention",
"xformers_attention", "xformers_attention",
"flex_attention",
"sage_attention",
"eager_attention",
"s2_attention",
]: ]:
assert not result.get(flag) assert flag not in result
def test_custom_string_passthrough(self): def test_s2_plus_flash_priority_is_s2(self):
result = self._normalize({"attn_implementation": "my_custom_kernel"}) result = self._normalize({"s2_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "my_custom_kernel" assert result["attn_implementation"] == "s2"
# --- No attention set --- def test_sage_plus_flash_priority_is_sage(self):
result = self._normalize({"sage_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "sage"
class TestNormalizerConflicts:
"""Canonical attn_implementation + legacy flag raises."""
@staticmethod
def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data)
def test_canonical_plus_legacy_flag_raises(self):
with pytest.raises(ValueError, match="cannot be combined with legacy"):
self._normalize(
{"attn_implementation": "flash_attention_2", "flash_attention": True}
)
def test_canonical_plus_unrelated_legacy_flag_raises(self):
with pytest.raises(ValueError, match="cannot be combined with legacy"):
self._normalize(
{"attn_implementation": "xformers", "flash_attention": True}
)
class TestNormalizerPassthrough:
"""Canonical values and hub-kernel paths pass through."""
def test_canonical_no_legacy_is_noop(self):
data = {"attn_implementation": "flash_attention_2"}
result = AxolotlInputConfig.normalize_attn_implementation(data)
assert result["attn_implementation"] == "flash_attention_2"
def test_hub_kernel_passes_through(self):
data = {"attn_implementation": "kernels-community/flash-attn3"}
result = AxolotlInputConfig.normalize_attn_implementation(data)
assert result["attn_implementation"] == "kernels-community/flash-attn3"
def test_no_attention_set_is_noop(self): def test_no_attention_set_is_noop(self):
result = self._normalize({"some_other_config": True}) result = AxolotlInputConfig.normalize_attn_implementation(
{"some_other_config": True}
)
assert result.get("attn_implementation") is None assert result.get("attn_implementation") is None
# --- Gemma4 hybrid ---
def test_gemma4_hybrid_sets_flash(self): class TestGemma4Hybrid:
"""gemma4_hybrid_attn_impl should default attn_implementation to flash.""" """gemma4_hybrid_attn_impl defaults to flash_attention_2."""
result = self._normalize({"gemma4_hybrid_attn_impl": True})
assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True
def test_gemma4_hybrid_does_not_override_explicit(self): def test_gemma4_hybrid_defaults_to_fa2(self):
"""If attn_implementation is already set, gemma4 should not override it.""" result = AxolotlInputConfig.normalize_attn_implementation(
result = self._normalize( {"gemma4_hybrid_attn_impl": True}
)
assert result["attn_implementation"] == "flash_attention_2"
def test_gemma4_hybrid_respects_explicit(self):
result = AxolotlInputConfig.normalize_attn_implementation(
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"} {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
) )
assert result["attn_implementation"] == "sdpa" assert result["attn_implementation"] == "sdpa"
class TestAttnCapabilityProperties: class TestFieldValidator:
"""Test the capability properties on the normalizer data. """attn_implementation field_validator rejects short-form aliases."""
Since these are @property on AxolotlInputConfig (a Pydantic model), def test_canonical_accepted(self):
we test the underlying logic directly using the constant sets. for impl in CANONICAL_ATTN_IMPLS:
""" assert AxolotlInputConfig.validate_attn_implementation(impl) == impl
# --- attn_supports_packing --- def test_hub_kernel_accepted(self):
for impl in (
"kernels-community/flash-attn3",
"kernels-community/sage-attention",
"someorg/custom-kernel",
):
assert AxolotlInputConfig.validate_attn_implementation(impl) == impl
@pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"]) def test_none_accepted(self):
def test_supports_packing_true(self, impl): assert AxolotlInputConfig.validate_attn_implementation(None) is None
assert impl not in _NON_PACKING_ATTN_IMPLS
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"]) @pytest.mark.parametrize("alias", ["flash", "flex", "sdp"])
def test_supports_packing_false(self, impl): def test_short_form_alias_rejected(self, alias):
assert impl in _NON_PACKING_ATTN_IMPLS with pytest.raises(ValueError, match="is not accepted"):
AxolotlInputConfig.validate_attn_implementation(alias)
def test_hub_kernel_supports_packing(self): def test_unknown_without_slash_rejected(self):
"""Unknown hub kernels should default to packing-capable.""" with pytest.raises(ValueError, match="not a recognized backend"):
assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS AxolotlInputConfig.validate_attn_implementation("not_a_real_backend")
# --- attn_uses_flash_lib ---
@pytest.mark.parametrize("impl", ["flash", "s2"]) class TestCapabilityTables:
def test_uses_flash_lib_true(self, impl): """Capability tables are keyed by canonical names and cover the expected backends."""
assert impl in FLASH_ATTN_LIB_IMPLS
@pytest.mark.parametrize( @pytest.mark.parametrize(
"impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"] "impl",
[
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
],
) )
def test_uses_flash_lib_false(self, impl): def test_supports_packing(self, impl):
assert impl not in FLASH_ATTN_LIB_IMPLS assert impl in ATTN_IMPLS_SUPPORTING_PACKING
def test_hub_kernel_not_flash_lib(self): @pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
"""Hub kernels are HF-managed, not axolotl monkeypatch targets.""" def test_does_not_support_packing(self, impl):
assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS assert impl not in ATTN_IMPLS_SUPPORTING_PACKING
# --- attn_needs_dtype_cast --- @pytest.mark.parametrize("impl", ["flash_attention_2", "flash_attention_3", "s2"])
def test_uses_flash_lib(self, impl):
assert impl in ATTN_IMPLS_USING_FLASH_LIB
@pytest.mark.parametrize(
"impl", ["eager", "sdpa", "xformers", "flex_attention", "sage", "fp8"]
)
def test_does_not_use_flash_lib(self, impl):
assert impl not in ATTN_IMPLS_USING_FLASH_LIB
@pytest.mark.parametrize("impl", ["eager", "sdpa"]) @pytest.mark.parametrize("impl", ["eager", "sdpa"])
def test_no_dtype_cast(self, impl): def test_no_dtype_cast(self, impl):
assert impl in _NO_DTYPE_CAST_ATTN_IMPLS assert impl in ATTN_IMPLS_WITHOUT_DTYPE_CAST
@pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"])
def test_needs_dtype_cast(self, impl):
assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
class TestAttnImplToHFMapping:
"""Test that attn_implementation enum values map correctly to HF strings."""
# This dict mirrors _ATTN_IMPL_TO_HF in model.py
_ATTN_IMPL_TO_HF = {
"eager": "eager",
"flash": "flash_attention_2",
"sdpa": "sdpa",
"xformers": "xformers",
"flex": "flex_attention",
"sage": "sage",
"s2": "flash_attention_2",
"fp8": "sdpa",
}
@pytest.mark.parametrize( @pytest.mark.parametrize(
"impl,expected_hf", "impl",
[ [
("eager", "eager"), "flash_attention_2",
("flash", "flash_attention_2"), "flash_attention_3",
("sdpa", "sdpa"), "flex_attention",
("xformers", "xformers"), "xformers",
("flex", "flex_attention"), "sage",
("sage", "sage"), "s2",
("s2", "flash_attention_2"), "fp8",
("fp8", "sdpa"),
], ],
) )
def test_known_impl_maps_correctly(self, impl, expected_hf): def test_needs_dtype_cast(self, impl):
assert self._ATTN_IMPL_TO_HF[impl] == expected_hf assert impl not in ATTN_IMPLS_WITHOUT_DTYPE_CAST
def test_hub_kernel_falls_through(self): def test_known_hub_kernels_classified(self):
"""Hub kernel strings should pass through .get() unchanged.""" assert "kernels-community/flash-attn3" in ATTN_IMPLS_SUPPORTING_PACKING
hub_str = "kernels-community/flash-attn3" assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB
result = self._ATTN_IMPL_TO_HF.get(hub_str, hub_str) assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING
assert result == hub_str
def _xformers_available(): def _xformers_available():
@@ -279,7 +212,7 @@ def _xformers_available():
class TestAttentionRegistration: class TestAttentionRegistration:
"""Test that attention backends register correctly in HF's registries.""" """Axolotl-owned backends register under their canonical names in HF's registries."""
@pytest.mark.skipif(not _xformers_available(), reason="xformers not available") @pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
def test_register_xformers(self): def test_register_xformers(self):
@@ -292,7 +225,6 @@ class TestAttentionRegistration:
assert "xformers" in ALL_ATTENTION_FUNCTIONS assert "xformers" in ALL_ATTENTION_FUNCTIONS
assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS
# xformers mask should be the same function as flash_attention_2's mask
assert ( assert (
ALL_MASK_ATTENTION_FUNCTIONS["xformers"] ALL_MASK_ATTENTION_FUNCTIONS["xformers"]
== ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
@@ -315,7 +247,6 @@ class TestAttentionRegistration:
@pytest.mark.skipif(not _xformers_available(), reason="xformers not available") @pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
def test_xformers_does_not_overwrite_fa2(self): def test_xformers_does_not_overwrite_fa2(self):
"""Registering xformers should not modify the flash_attention_2 slot."""
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
@@ -327,7 +258,6 @@ class TestAttentionRegistration:
assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2 assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2
def test_sage_does_not_overwrite_fa2(self): def test_sage_does_not_overwrite_fa2(self):
"""Registering sage should not modify the flash_attention_2 slot."""
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]