make attn_implementation the single source of truth

This commit is contained in:
Wing Lian
2026-04-23 21:17:10 +00:00
parent 35d43fe141
commit 2579c496d5
10 changed files with 491 additions and 387 deletions

142
ATTN_REFACTOR_REVIEW.md Normal file
View File

@@ -0,0 +1,142 @@
# `attn-implementation-refactor` branch review
Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`).
Scope: 16 files, +682 / 96.
## 1. What the branch is trying to do
Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.
Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`.
Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
## 2. Target design
**`cfg.attn_implementation` is the single source of truth on the validated config.**
- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases:
- HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
- Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
- Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables.
Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.
This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet.
## 3. Gaps and holes
### A. Legacy flags are still parallel state, not input-only
1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only legacy-only path warns.
4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.
### B. Validators that still only check the legacy flag
5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned.
### C. Non-enum strings fall through the capability tables
7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as conflict. Common combo in existing YAMLs breaks hard on upgrade.
### D. Silent behaviour change for xformers
10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.
### E. Capability flags are writable Pydantic fields, not computed
11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.
### F. Validator ordering (not covered by tests)
12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way.
### G. Remaining legacy reads in `src/`
13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.
### H. fp8 operational risk
16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match.
### I. Test coverage gaps
17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**:
- Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
- Verify capability flags can't be overridden from YAML (item 11).
- Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
- Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
- Cover hub-kernel + legacy flag combo (item 9).
- Cover `flash_attention_2` canonicalization (item 7).
## 4. Fix plan
Four phases, each commit-sized. Phases 12 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.
### Phase 1 — Lock down the data model
1. Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name.
2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed:
- `flash_attention → flash_attention_2`
- `sdp_attention → sdpa`
- `flex_attention → flex_attention`
- `xformers_attention → xformers`
- `sage_attention → sage`
- `s2_attention → s2`
- `eager_attention → eager`
Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence).
**Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection.
**Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
- Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`.
- Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`.
- No-dtype-cast: `{eager, sdpa}`.
Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`).
4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.
### Phase 2 — Fix the validators
6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
8. Sweep items (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` dual-check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`.
### Phase 3 — Migrate remaining consumers
10. `lm_eval/cli.py:120``flash_attention=cfg.attn_uses_flash_lib`.
11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
13. `tests/e2e/multigpu/test_llama.py:524-526``cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
- Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keeps patches live).
- Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting.
### Phase 4 — Tests + docs
14. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
15. Add one test per gap closed above: `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
16. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
17. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (normalizer maps them with deprecation), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
## 5. Ordering & risk
- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 13 is the explicit decision point).
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
- Phase 3 step 13 is the only judgment call — flag for review before choosing.
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.

View File

@@ -502,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama
if (
self.cfg.attn_implementation == "flex"
self.cfg.attn_implementation == "flex_attention"
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or (
self.cfg.model_config_type in ["llama"]
and self.cfg.attn_implementation != "flash"
and self.cfg.attn_implementation != "flash_attention_2"
)
):
collator = V2BatchSamplerDataCollatorForSeq2Seq

View File

@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=(cfg.attn_implementation == "flash"),
flash_attention=cfg.attn_uses_flash_lib,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,

View File

@@ -628,33 +628,25 @@ class ModelLoader:
)
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
# Map attn_implementation enum values to HF attn_implementation strings.
# xformers/sage are registered in ALL_ATTENTION_FUNCTIONS and
# ALL_MASK_ATTENTION_FUNCTIONS under their own names with FA2 mask
# behavior, so they no longer need to masquerade as flash_attention_2.
# s2 still uses flash_attention_2 because it modifies FA2 internals.
# Hub kernel strings (e.g. "kernels-community/flash-attn3") fall
# through the .get() and are passed to HF unchanged.
_ATTN_IMPL_TO_HF = {
"eager": "eager",
"flash": "flash_attention_2",
"sdpa": "sdpa",
"xformers": "xformers",
"flex": "flex_attention",
"sage": "sage",
"s2": "flash_attention_2",
"fp8": "sdpa",
}
# s2 and fp8 need a different HF backend at load time than their
# canonical name: s2 patches FA2 internals, so load under FA2; fp8
# replaces F.scaled_dot_product_attention post-load, so load under sdpa.
# Every other canonical name (and hub-kernel paths) is passed through
# verbatim — xformers/sage/flash_attention_* are registered under their
# own names in ALL_ATTENTION_FUNCTIONS before model load.
_LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"}
if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers;
# global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
# Load with flash_attention_2 for sliding-window layers; global
# layers are swapped to sdpa post-load.
hf_impl = "flash_attention_2"
elif self.cfg.attn_implementation:
hf_impl = _ATTN_IMPL_TO_HF.get(
hf_impl = _LOAD_TIME_OVERRIDE.get(
self.cfg.attn_implementation, self.cfg.attn_implementation
)
else:
hf_impl = None
if hf_impl is not None:
self.model_kwargs["attn_implementation"] = hf_impl
self.model_config._attn_implementation = hf_impl

View File

@@ -333,7 +333,7 @@ class PatchManager:
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.attn_implementation == "flex":
if self.cfg.attn_implementation == "flex_attention":
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper,
)

View File

@@ -207,7 +207,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
# Mistral's official FA implementation requires left padding
if (
cfg.is_mistral_derived_model
and cfg.attn_implementation == "flash"
and cfg.attn_implementation == "flash_attention_2"
and not cfg.sample_packing
):
tokenizer.padding_side = "left"

View File

@@ -10,7 +10,9 @@ from pydantic import (
BaseModel,
Field,
StringConstraints,
computed_field,
field_serializer,
field_validator,
model_validator,
)
@@ -28,10 +30,12 @@ from axolotl.utils.schemas.datasets import (
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS,
_NON_PACKING_ATTN_IMPLS,
FLASH_ATTN_LIB_IMPLS,
AttnImplementation,
ATTN_IMPLS_SUPPORTING_PACKING,
ATTN_IMPLS_USING_FLASH_LIB,
ATTN_IMPLS_WITHOUT_DTYPE_CAST,
CANONICAL_ATTN_IMPLS,
LEGACY_ATTN_FLAG_TO_IMPL,
SHORT_FORM_ALIAS_TO_CANONICAL,
ChatTemplate,
RingAttnFunc,
RLType,
@@ -739,28 +743,35 @@ class AxolotlInputConfig(
xformers_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: xformers` instead.",
json_schema_extra={
"description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers"
"description": "[DEPRECATED] Use `attn_implementation: xformers`. https://github.com/facebookresearch/xformers"
},
)
sdp_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: sdpa` instead.",
json_schema_extra={
"description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html"
"description": "[DEPRECATED] Use `attn_implementation: sdpa`."
},
)
s2_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: s2` instead.",
json_schema_extra={
"description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf"
"description": "[DEPRECATED] Use `attn_implementation: s2`. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf"
},
)
flex_attention: bool | None = None
flex_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: flex_attention` instead.",
)
flex_attn_compile_kwargs: dict[str, Any] | None = None
flash_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: flash_attention_2` instead.",
json_schema_extra={
"description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention"
"description": "[DEPRECATED] Use `attn_implementation: flash_attention_2`. https://github.com/Dao-AILab/flash-attention"
},
)
flash_attn_cross_entropy: bool | None = Field(
@@ -787,17 +798,26 @@ class AxolotlInputConfig(
)
sage_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: sage` instead.",
json_schema_extra={
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
"description": "[DEPRECATED] Use `attn_implementation: sage`. https://github.com/thu-ml/SageAttention"
},
)
eager_attention: bool | None = None
eager_attention: bool | None = Field(
default=None,
deprecated="Use `attn_implementation: eager` instead.",
)
attn_implementation: AttnImplementation | str | None = Field(
attn_implementation: str | None = Field(
default=None,
json_schema_extra={
"description": "Attention backend: eager, flash, sdpa, xformers, flex, sage, s2, fp8, or a custom string for kernels."
"description": (
"Attention backend. Canonical values: eager, sdpa, flash_attention_2, "
"flash_attention_3, flex_attention, xformers, sage, s2, fp8. Hub-kernel "
"paths (e.g. kernels-community/flash-attn3) are also accepted and passed "
"through to transformers."
)
},
)
@@ -1335,29 +1355,24 @@ class AxolotlInputConfig(
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
# --- Attention capability flags (computed by normalize_attn_implementation) ---
# --- Attention capability flags (derived from attn_implementation) ---
attn_supports_packing: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether the attention backend supports varlen sample packing. "
"Computed automatically from attn_implementation."
},
)
attn_uses_flash_lib: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether the attention backend requires axolotl's flash_attn "
"monkeypatches. Computed automatically from attn_implementation."
},
)
attn_needs_dtype_cast: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether the attention backend needs embedding dtype cast to "
"fp16/bf16. Computed automatically from attn_implementation."
},
)
@computed_field # type: ignore[misc]
@property
def attn_supports_packing(self) -> bool:
return self.attn_implementation in ATTN_IMPLS_SUPPORTING_PACKING
@computed_field # type: ignore[misc]
@property
def attn_uses_flash_lib(self) -> bool:
return self.attn_implementation in ATTN_IMPLS_USING_FLASH_LIB
@computed_field # type: ignore[misc]
@property
def attn_needs_dtype_cast(self) -> bool:
if self.attn_implementation is None:
return False
return self.attn_implementation not in ATTN_IMPLS_WITHOUT_DTYPE_CAST
@model_validator(mode="before")
@classmethod
@@ -1382,90 +1397,83 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def normalize_attn_implementation(cls, data):
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
"""Map legacy boolean attention flags to the canonical `attn_implementation`.
`attn_implementation` is the single source of truth on the validated
config. Legacy booleans (`flash_attention: true`, …) are input-only
aliases; this validator warns, maps them to their canonical value, and
strips them from `data` so they cannot be read downstream.
Raises if a canonical `attn_implementation` is set alongside any legacy
boolean — users must pick one.
"""
if not isinstance(data, dict):
return data
attn_impl = data.get("attn_implementation")
set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)]
# If gemma4_hybrid_attn_impl is set but no attn_implementation, default
# to flash (the sliding-window layers use FA2, and packing should be enabled).
if data.get("gemma4_hybrid_attn_impl") and not attn_impl:
data["attn_implementation"] = "flash"
attn_impl = "flash"
# Mapping: attn_implementation value -> primary legacy flag to set
impl_to_flag = {
"eager": "eager_attention",
"flash": "flash_attention",
"sdpa": "sdp_attention",
"xformers": "xformers_attention",
"flex": "flex_attention",
"sage": "sage_attention",
"s2": "s2_attention",
"fp8": None, # new, no legacy flag
}
# Reverse mapping: legacy flag -> attn_implementation value
flag_to_impl = {
"eager_attention": "eager",
"flash_attention": "flash",
"sdp_attention": "sdpa",
"xformers_attention": "xformers",
"flex_attention": "flex",
"sage_attention": "sage",
"s2_attention": "s2",
}
# Find which legacy flags are set
set_flags = [f for f, impl in flag_to_impl.items() if data.get(f)]
# gemma4_hybrid defaults to flash_attention_2 when user didn't pick a
# backend. The sliding-window layers run under FA2; post-load patching
# swaps global layers to sdpa (see `_apply_gemma_hybrid_attention`).
if data.get("gemma4_hybrid_attn_impl") and not attn_impl and not set_flags:
data["attn_implementation"] = "flash_attention_2"
attn_impl = "flash_attention_2"
if attn_impl and set_flags:
# Both set — check consistency
expected_flag = impl_to_flag.get(attn_impl)
for flag in set_flags:
if flag != expected_flag:
raise ValueError(
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
f"Use only attn_implementation or the legacy flag, not both."
)
elif attn_impl and not set_flags:
# attn_implementation set, no legacy flags — set primary for backwards compat
flag = impl_to_flag.get(attn_impl)
if flag:
data[flag] = True
elif not attn_impl and set_flags:
# Legacy flags set, no attn_implementation — map to enum, warn
# Priority: specific backends first, then generic flash/sdp/eager
priority = [
"xformers_attention",
"s2_attention",
"sage_attention",
"flex_attention",
"flash_attention",
"sdp_attention",
"eager_attention",
]
for flag in priority:
raise ValueError(
f"attn_implementation={attn_impl!r} cannot be combined with legacy "
f"attention flags ({', '.join(sorted(set_flags))}). The legacy "
f"flags are deprecated — set only `attn_implementation`."
)
if not attn_impl and set_flags:
# Priority: specific backends beat generic flash/sdp/eager fallbacks.
for flag in LEGACY_ATTN_FLAG_TO_IMPL:
if flag in set_flags:
data["attn_implementation"] = flag_to_impl[flag]
canonical = LEGACY_ATTN_FLAG_TO_IMPL[flag]
data["attn_implementation"] = canonical
LOG.warning(
"`%s: true` is deprecated. Use `attn_implementation: %s` instead.",
"`%s: true` is deprecated and will be removed in a future "
"release. Use `attn_implementation: %s` instead.",
flag,
flag_to_impl[flag],
canonical,
)
break
# Compute capability flags from the final attn_implementation value
impl = data.get("attn_implementation")
if impl:
data["attn_supports_packing"] = impl not in _NON_PACKING_ATTN_IMPLS
data["attn_uses_flash_lib"] = impl in FLASH_ATTN_LIB_IMPLS
data["attn_needs_dtype_cast"] = impl not in _NO_DTYPE_CAST_ATTN_IMPLS
else:
data["attn_supports_packing"] = False
data["attn_uses_flash_lib"] = False
data["attn_needs_dtype_cast"] = False
# Strip legacy flags from validated data — canonical field is authoritative.
for flag in LEGACY_ATTN_FLAG_TO_IMPL:
data.pop(flag, None)
return data
@field_validator("attn_implementation", mode="before")
@classmethod
def validate_attn_implementation(cls, value):
"""Accept canonical names and hub-kernel paths; reject short-form aliases."""
if value is None:
return None
if not isinstance(value, str):
raise TypeError(
f"attn_implementation must be a string, got {type(value).__name__}"
)
if value in CANONICAL_ATTN_IMPLS:
return value
if "/" in value:
# Hub-kernel path, e.g. "kernels-community/flash-attn3". Pass through.
return value
if value in SHORT_FORM_ALIAS_TO_CANONICAL:
canonical = SHORT_FORM_ALIAS_TO_CANONICAL[value]
raise ValueError(
f"attn_implementation={value!r} is not accepted. "
f"Use the canonical name {canonical!r} instead."
)
raise ValueError(
f"attn_implementation={value!r} is not a recognized backend. "
f"Expected one of: {sorted(CANONICAL_ATTN_IMPLS)}, or a hub-kernel "
f"path containing '/'."
)
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
@@ -1763,7 +1771,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if data.get("flex_attention") or data.get("attn_implementation") == "flex":
if (
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")

View File

@@ -97,30 +97,75 @@ class CustomSupportedOptimizers(str, Enum):
flash_lion = "flash_lion"
class AttnImplementation(str, Enum):
"""Attention backend implementations"""
# Canonical values accepted for `attn_implementation`. These are passed to HF
# verbatim via `model.config._attn_implementation`. HF-native backends use HF's
# own names (`flash_attention_2`, `flex_attention`, ...); axolotl-owned backends
# (`xformers`, `sage`, `s2`, `fp8`) register into `ALL_ATTENTION_FUNCTIONS` under
# these exact names. Hub-kernel paths (e.g. `kernels-community/flash-attn3`) are
# not in this set — they pass through the validator via the "/" check.
CANONICAL_ATTN_IMPLS = frozenset(
{
"eager",
"sdpa",
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
"s2",
"fp8",
}
)
eager = "eager" # pylint: disable=invalid-name
flash = "flash" # pylint: disable=invalid-name
sdpa = "sdpa" # pylint: disable=invalid-name
xformers = "xformers" # pylint: disable=invalid-name
flex = "flex" # pylint: disable=invalid-name
sage = "sage" # pylint: disable=invalid-name
s2 = "s2" # pylint: disable=invalid-name
fp8 = "fp8" # pylint: disable=invalid-name
# Legacy boolean attention flags → canonical `attn_implementation`. Kept for
# backwards compatibility; the normalizer warns and strips these from the
# validated config. Priority order (first match wins) matches the old priority:
# specific backends beat the generic flash/sdp/eager fallbacks.
LEGACY_ATTN_FLAG_TO_IMPL = {
"xformers_attention": "xformers",
"s2_attention": "s2",
"sage_attention": "sage",
"flex_attention": "flex_attention",
"flash_attention": "flash_attention_2",
"sdp_attention": "sdpa",
"eager_attention": "eager",
}
# Short-form aliases that were accepted by the in-progress branch but are
# rejected going forward. Mapped to canonical names only to produce a helpful
# error message pointing users at the right value.
SHORT_FORM_ALIAS_TO_CANONICAL = {
"flash": "flash_attention_2",
"flex": "flex_attention",
"sdp": "sdpa",
}
# Backends that require the flash_attn library (Dao-AILab/flash-attention)
# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.)
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
# Backends that support varlen sample packing via `position_ids`.
ATTN_IMPLS_SUPPORTING_PACKING = frozenset(
{
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
"kernels-community/flash-attn3",
"kernels-community/sage-attention",
}
)
# Known backends that do NOT support varlen sample packing via position_ids.
# Used as an exclusion list: unknown strings (e.g., HF hub kernels like
# "kernels-community/flash-attn3") default to packing-capable.
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
# Backends that require the flash_attn library (Dao-AILab/flash-attention) for
# axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, ring-FA, ...).
ATTN_IMPLS_USING_FLASH_LIB = frozenset(
{
"flash_attention_2",
"flash_attention_3",
"s2",
"kernels-community/flash-attn3",
}
)
# Known backends that do NOT need embedding dtype cast.
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})
# Backends for which embeddings stay in fp32. Everything else needs fp16/bf16.
ATTN_IMPLS_WITHOUT_DTYPE_CAST = frozenset({"eager", "sdpa"})
class RingAttnFunc(str, Enum):

View File

@@ -13,7 +13,7 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import (
_NON_PACKING_ATTN_IMPLS,
ATTN_IMPLS_SUPPORTING_PACKING,
ChatTemplate,
RingAttnFunc,
RLType,
@@ -184,26 +184,8 @@ class DatasetValidationMixin:
class AttentionValidationMixin:
"""Validation methods related to attention mechanisms."""
@model_validator(mode="before")
@classmethod
def check_attention_fields(cls, data):
# If attn_implementation is set, the enum handles mutual exclusivity.
# This validator catches legacy configs with multiple boolean flags.
if data.get("attn_implementation"):
return data
fields = (
"xformers_attention",
"sdp_attention",
# "s2_attention", # requires both FA and this to be enabled
"flash_attention",
"flex_attention",
"sage_attention",
)
non_empty_count = sum(1 for field in fields if data.get(field))
if non_empty_count > 1:
raise ValueError(f"Only one of {', '.join(fields)} must be set")
return data
# `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
# is now the single entry point for attention-input mapping and conflict detection.
@model_validator(mode="before")
@classmethod
@@ -238,7 +220,8 @@ class AttentionValidationMixin:
@classmethod
def check_scaling_softmax_requires_flex(cls, data):
if data.get("scaling_softmax") and not (
data.get("flex_attention") or data.get("attn_implementation") == "flex"
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
):
raise ValueError(
"scaling_softmax requires flex attention.\n"
@@ -956,7 +939,7 @@ class OptimizationValidationMixin:
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
has_varlen_attn = (
data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS
data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING
if data.get("attn_implementation")
else data.get("flash_attention")
)
@@ -1683,7 +1666,8 @@ class EBFTValidationMixin:
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and (
data.get("flex_attention") or data.get("attn_implementation") == "flex"
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
)
and data.get("gradient_checkpointing")
):

View File

@@ -1,272 +1,205 @@
"""
Tests for attn_implementation normalization, registry registration,
capability properties, and backwards compatibility with legacy boolean
attention flags.
"""Tests for attn_implementation normalization and capability computation.
Covers the Phase 1 contract:
- `attn_implementation` accepts canonical names only; short-form aliases are rejected.
- Legacy boolean flags are mapped to the canonical value, warned on, and stripped.
- Canonical `attn_implementation` + legacy flag raises.
- Capability flags are computed from `attn_implementation`.
"""
import pytest
from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS,
_NON_PACKING_ATTN_IMPLS,
FLASH_ATTN_LIB_IMPLS,
ATTN_IMPLS_SUPPORTING_PACKING,
ATTN_IMPLS_USING_FLASH_LIB,
ATTN_IMPLS_WITHOUT_DTYPE_CAST,
CANONICAL_ATTN_IMPLS,
)
class TestAttnImplementationNormalizer:
"""Test the normalize_attn_implementation validator."""
class TestNormalizerLegacyMapping:
"""Legacy boolean flags map to canonical attn_implementation."""
@staticmethod
def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data)
# --- Forward mapping: attn_implementation -> legacy flags ---
@pytest.mark.parametrize(
"impl,expected_flag",
"flag,expected",
[
("eager", "eager_attention"),
("flash", "flash_attention"),
("sdpa", "sdp_attention"),
("flex", "flex_attention"),
("xformers", "xformers_attention"),
("sage", "sage_attention"),
("s2", "s2_attention"),
],
)
def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
data = {"attn_implementation": impl}
result = AxolotlInputConfig.normalize_attn_implementation(data)
assert result.get(expected_flag) is True, (
f"{impl}: expected {expected_flag}=True"
)
@pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
"""xformers, sage, s2 should NOT set flash_attention=True anymore."""
result = self._normalize({"attn_implementation": impl})
assert not result.get("flash_attention"), (
f"{impl} should not set flash_attention"
)
def test_fp8_sets_no_legacy_flags(self):
result = self._normalize({"attn_implementation": "fp8"})
for flag in [
"flash_attention",
"sdp_attention",
"eager_attention",
"xformers_attention",
"sage_attention",
"flex_attention",
"s2_attention",
]:
assert not result.get(flag), f"fp8 should not set {flag}"
# --- Reverse mapping: legacy flags -> attn_implementation ---
@pytest.mark.parametrize(
"flag,expected_impl",
[
("flash_attention", "flash"),
("flash_attention", "flash_attention_2"),
("sdp_attention", "sdpa"),
("xformers_attention", "xformers"),
("flex_attention", "flex"),
("flex_attention", "flex_attention"),
("sage_attention", "sage"),
("eager_attention", "eager"),
("s2_attention", "s2"),
],
)
def test_legacy_flag_sets_attn_impl(self, flag, expected_impl):
def test_legacy_flag_maps_to_canonical(self, flag, expected):
result = self._normalize({flag: True})
assert result["attn_implementation"] == expected_impl
assert result["attn_implementation"] == expected
# --- Priority: s2/sage should win over flash when both set ---
def test_s2_plus_flash_maps_to_s2(self):
"""Legacy configs often have both s2_attention and flash_attention."""
result = self._normalize({"s2_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "s2"
def test_sage_plus_flash_maps_to_sage(self):
"""sage_attention should take priority over flash_attention."""
result = self._normalize({"sage_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "sage"
# --- Consistency: both set, matching ---
def test_consistent_both_set_no_error(self):
result = self._normalize(
{"attn_implementation": "flash", "flash_attention": True}
)
assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True
def test_consistent_xformers_with_own_flag(self):
"""xformers + xformers_attention should be OK."""
result = self._normalize(
{"attn_implementation": "xformers", "xformers_attention": True}
)
assert result["attn_implementation"] == "xformers"
# --- Conflict detection ---
def test_conflicting_impl_and_flag_raises(self):
with pytest.raises(ValueError, match="conflicts with"):
self._normalize({"attn_implementation": "flash", "sdp_attention": True})
def test_conflicting_xformers_impl_with_sdp_flag(self):
with pytest.raises(ValueError, match="conflicts with"):
self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
def test_xformers_with_flash_flag_conflicts(self):
"""After normalizer change, xformers no longer expects flash_attention."""
with pytest.raises(ValueError, match="conflicts with"):
self._normalize(
{
"attn_implementation": "xformers",
"xformers_attention": True,
"flash_attention": True,
}
)
def test_s2_with_flash_flag_conflicts(self):
"""After normalizer change, s2 no longer expects flash_attention."""
with pytest.raises(ValueError, match="conflicts with"):
self._normalize(
{
"attn_implementation": "s2",
"s2_attention": True,
"flash_attention": True,
}
)
# --- Hub kernel strings pass through ---
def test_hub_kernel_passthrough(self):
result = self._normalize(
{"attn_implementation": "kernels-community/flash-attn3"}
)
assert result["attn_implementation"] == "kernels-community/flash-attn3"
# Should not set any legacy flags
def test_legacy_flags_are_stripped_after_mapping(self):
result = self._normalize({"flash_attention": True})
for flag in [
"flash_attention",
"sdp_attention",
"eager_attention",
"xformers_attention",
"flex_attention",
"sage_attention",
"eager_attention",
"s2_attention",
]:
assert not result.get(flag)
assert flag not in result
def test_custom_string_passthrough(self):
result = self._normalize({"attn_implementation": "my_custom_kernel"})
assert result["attn_implementation"] == "my_custom_kernel"
def test_s2_plus_flash_priority_is_s2(self):
result = self._normalize({"s2_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "s2"
# --- No attention set ---
def test_sage_plus_flash_priority_is_sage(self):
result = self._normalize({"sage_attention": True, "flash_attention": True})
assert result["attn_implementation"] == "sage"
class TestNormalizerConflicts:
"""Canonical attn_implementation + legacy flag raises."""
@staticmethod
def _normalize(data):
return AxolotlInputConfig.normalize_attn_implementation(data)
def test_canonical_plus_legacy_flag_raises(self):
with pytest.raises(ValueError, match="cannot be combined with legacy"):
self._normalize(
{"attn_implementation": "flash_attention_2", "flash_attention": True}
)
def test_canonical_plus_unrelated_legacy_flag_raises(self):
with pytest.raises(ValueError, match="cannot be combined with legacy"):
self._normalize(
{"attn_implementation": "xformers", "flash_attention": True}
)
class TestNormalizerPassthrough:
"""Canonical values and hub-kernel paths pass through."""
def test_canonical_no_legacy_is_noop(self):
data = {"attn_implementation": "flash_attention_2"}
result = AxolotlInputConfig.normalize_attn_implementation(data)
assert result["attn_implementation"] == "flash_attention_2"
def test_hub_kernel_passes_through(self):
data = {"attn_implementation": "kernels-community/flash-attn3"}
result = AxolotlInputConfig.normalize_attn_implementation(data)
assert result["attn_implementation"] == "kernels-community/flash-attn3"
def test_no_attention_set_is_noop(self):
result = self._normalize({"some_other_config": True})
result = AxolotlInputConfig.normalize_attn_implementation(
{"some_other_config": True}
)
assert result.get("attn_implementation") is None
# --- Gemma4 hybrid ---
def test_gemma4_hybrid_sets_flash(self):
"""gemma4_hybrid_attn_impl should default attn_implementation to flash."""
result = self._normalize({"gemma4_hybrid_attn_impl": True})
assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True
class TestGemma4Hybrid:
"""gemma4_hybrid_attn_impl defaults to flash_attention_2."""
def test_gemma4_hybrid_does_not_override_explicit(self):
"""If attn_implementation is already set, gemma4 should not override it."""
result = self._normalize(
def test_gemma4_hybrid_defaults_to_fa2(self):
result = AxolotlInputConfig.normalize_attn_implementation(
{"gemma4_hybrid_attn_impl": True}
)
assert result["attn_implementation"] == "flash_attention_2"
def test_gemma4_hybrid_respects_explicit(self):
result = AxolotlInputConfig.normalize_attn_implementation(
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
)
assert result["attn_implementation"] == "sdpa"
class TestAttnCapabilityProperties:
"""Test the capability properties on the normalizer data.
class TestFieldValidator:
"""attn_implementation field_validator rejects short-form aliases."""
Since these are @property on AxolotlInputConfig (a Pydantic model),
we test the underlying logic directly using the constant sets.
"""
def test_canonical_accepted(self):
for impl in CANONICAL_ATTN_IMPLS:
assert AxolotlInputConfig.validate_attn_implementation(impl) == impl
# --- attn_supports_packing ---
def test_hub_kernel_accepted(self):
for impl in (
"kernels-community/flash-attn3",
"kernels-community/sage-attention",
"someorg/custom-kernel",
):
assert AxolotlInputConfig.validate_attn_implementation(impl) == impl
@pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"])
def test_supports_packing_true(self, impl):
assert impl not in _NON_PACKING_ATTN_IMPLS
def test_none_accepted(self):
assert AxolotlInputConfig.validate_attn_implementation(None) is None
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
def test_supports_packing_false(self, impl):
assert impl in _NON_PACKING_ATTN_IMPLS
@pytest.mark.parametrize("alias", ["flash", "flex", "sdp"])
def test_short_form_alias_rejected(self, alias):
with pytest.raises(ValueError, match="is not accepted"):
AxolotlInputConfig.validate_attn_implementation(alias)
def test_hub_kernel_supports_packing(self):
"""Unknown hub kernels should default to packing-capable."""
assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS
def test_unknown_without_slash_rejected(self):
with pytest.raises(ValueError, match="not a recognized backend"):
AxolotlInputConfig.validate_attn_implementation("not_a_real_backend")
# --- attn_uses_flash_lib ---
@pytest.mark.parametrize("impl", ["flash", "s2"])
def test_uses_flash_lib_true(self, impl):
assert impl in FLASH_ATTN_LIB_IMPLS
class TestCapabilityTables:
"""Capability tables are keyed by canonical names and cover the expected backends."""
@pytest.mark.parametrize(
"impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"]
"impl",
[
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
],
)
def test_uses_flash_lib_false(self, impl):
assert impl not in FLASH_ATTN_LIB_IMPLS
def test_supports_packing(self, impl):
assert impl in ATTN_IMPLS_SUPPORTING_PACKING
def test_hub_kernel_not_flash_lib(self):
"""Hub kernels are HF-managed, not axolotl monkeypatch targets."""
assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
def test_does_not_support_packing(self, impl):
assert impl not in ATTN_IMPLS_SUPPORTING_PACKING
# --- attn_needs_dtype_cast ---
@pytest.mark.parametrize("impl", ["flash_attention_2", "flash_attention_3", "s2"])
def test_uses_flash_lib(self, impl):
assert impl in ATTN_IMPLS_USING_FLASH_LIB
@pytest.mark.parametrize(
"impl", ["eager", "sdpa", "xformers", "flex_attention", "sage", "fp8"]
)
def test_does_not_use_flash_lib(self, impl):
assert impl not in ATTN_IMPLS_USING_FLASH_LIB
@pytest.mark.parametrize("impl", ["eager", "sdpa"])
def test_no_dtype_cast(self, impl):
assert impl in _NO_DTYPE_CAST_ATTN_IMPLS
@pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"])
def test_needs_dtype_cast(self, impl):
assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
class TestAttnImplToHFMapping:
"""Test that attn_implementation enum values map correctly to HF strings."""
# This dict mirrors _ATTN_IMPL_TO_HF in model.py
_ATTN_IMPL_TO_HF = {
"eager": "eager",
"flash": "flash_attention_2",
"sdpa": "sdpa",
"xformers": "xformers",
"flex": "flex_attention",
"sage": "sage",
"s2": "flash_attention_2",
"fp8": "sdpa",
}
assert impl in ATTN_IMPLS_WITHOUT_DTYPE_CAST
@pytest.mark.parametrize(
"impl,expected_hf",
"impl",
[
("eager", "eager"),
("flash", "flash_attention_2"),
("sdpa", "sdpa"),
("xformers", "xformers"),
("flex", "flex_attention"),
("sage", "sage"),
("s2", "flash_attention_2"),
("fp8", "sdpa"),
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"xformers",
"sage",
"s2",
"fp8",
],
)
def test_known_impl_maps_correctly(self, impl, expected_hf):
assert self._ATTN_IMPL_TO_HF[impl] == expected_hf
def test_needs_dtype_cast(self, impl):
assert impl not in ATTN_IMPLS_WITHOUT_DTYPE_CAST
def test_hub_kernel_falls_through(self):
"""Hub kernel strings should pass through .get() unchanged."""
hub_str = "kernels-community/flash-attn3"
result = self._ATTN_IMPL_TO_HF.get(hub_str, hub_str)
assert result == hub_str
def test_known_hub_kernels_classified(self):
assert "kernels-community/flash-attn3" in ATTN_IMPLS_SUPPORTING_PACKING
assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB
assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING
def _xformers_available():
@@ -279,7 +212,7 @@ def _xformers_available():
class TestAttentionRegistration:
"""Test that attention backends register correctly in HF's registries."""
"""Axolotl-owned backends register under their canonical names in HF's registries."""
@pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
def test_register_xformers(self):
@@ -292,7 +225,6 @@ class TestAttentionRegistration:
assert "xformers" in ALL_ATTENTION_FUNCTIONS
assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS
# xformers mask should be the same function as flash_attention_2's mask
assert (
ALL_MASK_ATTENTION_FUNCTIONS["xformers"]
== ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
@@ -315,7 +247,6 @@ class TestAttentionRegistration:
@pytest.mark.skipif(not _xformers_available(), reason="xformers not available")
def test_xformers_does_not_overwrite_fa2(self):
"""Registering xformers should not modify the flash_attention_2 slot."""
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
@@ -327,7 +258,6 @@ class TestAttentionRegistration:
assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2
def test_sage_does_not_overwrite_fa2(self):
"""Registering sage should not modify the flash_attention_2 slot."""
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]