replace legacy attention boolean flags with capability properties

Replace legacy boolean attention flag checks with capability-based properties derived from the attn_implementation enum.

This separates three concerns that were previously conflated under the flash_attention flag:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property
Wing Lian
2026-04-12 22:01:09 -04:00
parent aee8c75d64
commit ff5d6393c8
13 changed files with 274 additions and 136 deletions
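For orientation, below is a minimal standalone sketch of the capability checks this commit introduces. The constants and membership rules are copied from the enums and config hunks later in this diff; the helper functions are illustrative stand-ins for the new AxolotlInputConfig properties (including the dtype-cast helper also added here), not the actual axolotl module.

# Illustrative stand-ins for the new capability properties (not the real module).
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})


def attn_supports_packing(impl: str | None) -> bool:
    # Unknown strings (e.g. hub kernels) default to packing-capable.
    return bool(impl) and impl not in _NON_PACKING_ATTN_IMPLS


def attn_uses_flash_lib(impl: str | None) -> bool:
    # Only axolotl-managed flash_attn setups (flash, s2) need the monkeypatches.
    return impl in FLASH_ATTN_LIB_IMPLS


def attn_needs_dtype_cast(impl: str | None) -> bool:
    # Unknown backends default to True; the cast is harmless if unnecessary.
    return bool(impl) and impl not in _NO_DTYPE_CAST_ATTN_IMPLS


assert attn_supports_packing("kernels-community/flash-attn3")  # hub kernel passthrough
assert not attn_supports_packing("sdpa")
assert attn_uses_flash_lib("s2") and not attn_uses_flash_lib("xformers")
assert attn_needs_dtype_cast("flash") and not attn_needs_dtype_cast("eager")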

View File

@@ -147,7 +147,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
quantize_moe_experts=False,
flash_attention=False,
attn_implementation=None,
context_parallel_size=None,
deepspeed=None,
fsdp=None,

View File

@@ -257,19 +257,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
self.cfg.flash_attention
or self.cfg.xformers_attention
or self.cfg.flex_attention
training_arguments_kwargs["sample_packing_drop_attention_mask"] = (
self.cfg.attn_supports_packing
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not (
self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.xformers_attention
)
else not self.cfg.attn_supports_packing
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
@@ -508,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama
if (
self.cfg.flex_attention
self.cfg.attn_implementation == "flex"
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
and self.cfg.attn_implementation != "flash"
)
):
collator = V2BatchSamplerDataCollatorForSeq2Seq

View File

@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
flash_attention=(cfg.attn_implementation == "flash"),
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,

View File

@@ -383,7 +383,9 @@ class SwanLabPlugin(BasePlugin):
"seed": safe_convert(getattr(cfg, "seed", None)),
"bf16": safe_convert(getattr(cfg, "bf16", None)),
"tf32": safe_convert(getattr(cfg, "tf32", None)),
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
"attn_implementation": safe_convert(
getattr(cfg, "attn_implementation", None)
),
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
}

View File

@@ -343,12 +343,7 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
(needs_fa2_dtype or self.cfg.attn_needs_dtype_cast)
and not self.is_qlora_and_fsdp_enabled
)
or (
@@ -656,32 +651,12 @@ class ModelLoader:
# global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
# Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True
elif self.cfg.attn_implementation:
hf_impl = _ATTN_IMPL_TO_HF.get(
self.cfg.attn_implementation, self.cfg.attn_implementation
)
self.model_kwargs["attn_implementation"] = hf_impl
self.model_config._attn_implementation = hf_impl
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = "flex_attention"
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager"
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
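The new branch above routes every configured backend through _ATTN_IMPL_TO_HF. That lookup is not shown in this diff; a plausible shape, assuming the HF target strings visible in this hunk, is sketched here.

# Hypothetical sketch of the _ATTN_IMPL_TO_HF lookup referenced above (the real
# mapping lives elsewhere in the codebase). Keys without an entry, such as hub
# kernel ids, fall through unchanged via dict.get(impl, impl).
_ATTN_IMPL_TO_HF = {
    "eager": "eager",
    "sdpa": "sdpa",
    "flash": "flash_attention_2",
    "flex": "flex_attention",
}

impl = "kernels-community/flash-attn3"
hf_impl = _ATTN_IMPL_TO_HF.get(impl, impl)  # passed to HF as-is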

View File

@@ -253,7 +253,7 @@ class PatchManager:
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention:
if self.cfg.attn_implementation == "xformers":
from axolotl.monkeypatch.attention import register_xformers_attn
register_xformers_attn()
@@ -263,9 +263,8 @@ class PatchManager:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.sage_attention:
if self.cfg.attn_implementation == "sage":
from axolotl.monkeypatch.attention import register_sage_attn
register_sage_attn()
@@ -334,7 +333,7 @@ class PatchManager:
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.flex_attention:
if self.cfg.attn_implementation == "flex":
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper,
)
@@ -344,14 +343,14 @@ class PatchManager:
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
if self.cfg.attn_implementation == "sage":
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention:
if not self.cfg.attn_uses_flash_lib:
return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
@@ -420,7 +419,7 @@ class PatchManager:
if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal
and self.cfg.flash_attention
and self.cfg.attn_uses_flash_lib
):
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention,
@@ -572,7 +571,7 @@ class PatchManager:
"""Apply multipack patches if necessary."""
if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.attn_supports_packing
and self.cfg.sample_packing
):
# Get automap config if it exists
@@ -693,7 +692,9 @@ class PatchManager:
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
if not (
self.cfg.attn_uses_flash_lib and hasattr(self.model_config, "model_type")
):
return
if self.model_config.model_type == "btlm":
@@ -739,7 +740,7 @@ class PatchManager:
replace_llama_attn_with_flash_attn,
)
if self.cfg.s2_attention:
if self.cfg.attn_implementation == "s2":
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
cross_entropy=self.cfg.flash_attn_cross_entropy,
@@ -765,14 +766,14 @@ class PatchManager:
"""Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.attn_supports_packing
and self.cfg.sample_packing
):
if self.cfg.flash_attention:
if self.cfg.attn_uses_flash_lib:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
elif self.cfg.attn_implementation == "xformers":
self._patch_llama_xformers_attention()
elif self.cfg.s2_attention:
elif self.cfg.attn_implementation == "s2":
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
)
@@ -784,7 +785,7 @@ class PatchManager:
in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and self.cfg.attn_uses_flash_lib
and is_flash_attn_available()
and not self.inference
):

View File

@@ -205,7 +205,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
if (
cfg.is_mistral_derived_model
and cfg.attn_implementation == "flash"
and not cfg.sample_packing
):
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens

View File

@@ -955,9 +955,9 @@ def colab_inference_post_train_callback(trainer: Trainer):
"""
handle T4 gpu, we need to convert attention to eager for inference
"""
if "Tesla T4" in self.gpu_name and (
self.cfg.xformers_attention
or self.cfg.attn_implementation == "xformers"
if (
"Tesla T4" in self.gpu_name
and self.cfg.attn_implementation == "xformers"
):
trainer.model.config._attn_implementation = "eager"
trainer.model.gradient_checkpointing_disable()

View File

@@ -28,6 +28,9 @@ from axolotl.utils.schemas.datasets import (
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS,
_NON_PACKING_ATTN_IMPLS,
FLASH_ATTN_LIB_IMPLS,
AttnImplementation,
ChatTemplate,
RingAttnFunc,
@@ -1332,6 +1335,40 @@ class AxolotlInputConfig(
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
# --- Attention capability properties ---
@property
def attn_supports_packing(self) -> bool:
"""True if attention supports varlen sample packing via position_ids.
Known varlen backends: flash, flex, xformers, sage.
Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
default to True since they generally support varlen.
"""
if not self.attn_implementation:
return False
return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
@property
def attn_uses_flash_lib(self) -> bool:
"""True if the backend uses axolotl's flash_attn monkeypatches.
Only for axolotl-managed FA setup (flash, s2). Hub kernels are
HF-managed and don't need these patches.
"""
return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
@property
def attn_needs_dtype_cast(self) -> bool:
"""True if attention needs embedding dtype cast to fp16/bf16.
Unknown backends (hub kernels) default to True (safe -- harmless
if unnecessary, but missing cast causes errors).
"""
if not self.attn_implementation:
return False
return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
@model_validator(mode="before")
@classmethod
def warn_peft_trainable_token_to_fix_untrained(cls, data):
@@ -1358,16 +1395,22 @@ class AxolotlInputConfig(
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
attn_impl = data.get("attn_implementation")
# Mapping: attn_implementation value -> (primary flag, extra flags to set)
impl_to_flags = {
"eager": (("eager_attention",), ()),
"flash": (("flash_attention",), ()),
"sdpa": (("sdp_attention",), ()),
"xformers": (("xformers_attention",), ("flash_attention",)),
"flex": (("flex_attention",), ()),
"sage": (("sage_attention",), ("flash_attention",)),
"s2": (("s2_attention",), ("flash_attention",)),
"fp8": ((), ()), # new, no legacy flags
# If gemma4_hybrid_attn_impl is set but no attn_implementation, default
# to flash (the sliding-window layers use FA2, and packing should be enabled).
if data.get("gemma4_hybrid_attn_impl") and not attn_impl:
data["attn_implementation"] = "flash"
attn_impl = "flash"
# Mapping: attn_implementation value -> primary legacy flag to set
impl_to_flag = {
"eager": "eager_attention",
"flash": "flash_attention",
"sdpa": "sdp_attention",
"xformers": "xformers_attention",
"flex": "flex_attention",
"sage": "sage_attention",
"s2": "s2_attention",
"fp8": None, # new, no legacy flag
}
# Reverse mapping: legacy flag -> attn_implementation value
@@ -1386,26 +1429,21 @@ class AxolotlInputConfig(
if attn_impl and set_flags:
# Both set — check consistency
if attn_impl in impl_to_flags:
expected_primary, expected_extra = impl_to_flags[attn_impl]
expected_flags = set(expected_primary) | set(expected_extra)
for flag in set_flags:
if flag not in expected_flags:
raise ValueError(
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
f"Use only attn_implementation or the legacy flag, not both."
)
expected_flag = impl_to_flag.get(attn_impl)
for flag in set_flags:
if flag != expected_flag:
raise ValueError(
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
f"Use only attn_implementation or the legacy flag, not both."
)
elif attn_impl and not set_flags:
# attn_implementation set, no legacy flags — set them for backwards compat
if attn_impl in impl_to_flags:
primary, extra = impl_to_flags[attn_impl]
for flag in (*primary, *extra):
data[flag] = True
# attn_implementation set, no legacy flags — set primary for backwards compat
flag = impl_to_flag.get(attn_impl)
if flag:
data[flag] = True
elif not attn_impl and set_flags:
# Legacy flags set, no attn_implementation — map to enum, warn
# Priority: specific backends first, then generic flash/sdp/eager
# s2 and sage require flash_attention internally, so they must be
# checked before flash_attention to avoid masking
priority = [
"xformers_attention",
"s2_attention",
@@ -1430,7 +1468,10 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("sample_packing", False)) and is_sage:
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
@@ -1441,7 +1482,10 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
if (not data.get("adapter", False)) and data.get("sage_attention"):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("adapter", False)) and is_sage:
LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
@@ -1531,7 +1575,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
if (
data.get("sample_packing")
and data.get("sdp_attention")
and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
and (data.get("bfloat16") or data.get("bf16"))
and not is_sm_90
):
@@ -1546,8 +1590,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (
data.get("sage_attention")
is_sage
and data.get("capabilities")
and data.get("capabilities").get("compute_capability")
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
@@ -1715,7 +1762,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (data.get("flex_attention")):
if data.get("flex_attention") or data.get("attn_implementation") == "flex":
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
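A short usage sketch of the normalizer behavior changed in this file, mirroring how the tests added by this commit call the validator directly (illustrative inputs only):

from axolotl.utils.schemas.config import AxolotlInputConfig

# Forward direction: the enum now sets only its primary legacy flag, so
# xformers no longer implies flash_attention=True.
out = AxolotlInputConfig.normalize_attn_implementation(
    {"attn_implementation": "xformers"}
)
assert out.get("xformers_attention") is True
assert not out.get("flash_attention")

# Reverse direction: a lone legacy flag still maps onto the enum (with a
# warning) for backwards compatibility.
out = AxolotlInputConfig.normalize_attn_implementation({"sage_attention": True})
assert out.get("attn_implementation") == "sage"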

View File

@@ -110,6 +110,19 @@ class AttnImplementation(str, Enum):
fp8 = "fp8" # pylint: disable=invalid-name
# Backends that require the flash_attn library (Dao-AILab/flash-attention)
# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.)
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
# Known backends that do NOT support varlen sample packing via position_ids.
# Used as an exclusion list: unknown strings (e.g., HF hub kernels like
# "kernels-community/flash-attn3") default to packing-capable.
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
# Known backends that do NOT need embedding dtype cast.
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""

View File

@@ -12,7 +12,12 @@ from pydantic import (
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.enums import (
_NON_PACKING_ATTN_IMPLS,
ChatTemplate,
RingAttnFunc,
RLType,
)
LOG = get_logger(__name__)
@@ -182,6 +187,10 @@ class AttentionValidationMixin:
@model_validator(mode="before")
@classmethod
def check_attention_fields(cls, data):
# If attn_implementation is set, the enum handles mutual exclusivity.
# This validator catches legacy configs with multiple boolean flags.
if data.get("attn_implementation"):
return data
fields = (
"xformers_attention",
"sdp_attention",
@@ -436,7 +445,7 @@ class TrainingValidationMixin:
not (self.bf16 or self.bfloat16)
and (self.fp16 or self.float16)
and not self.adapter
and not self.flash_attention
and not self.attn_uses_flash_lib
and self.sample_packing
):
LOG.warning(
@@ -946,8 +955,16 @@ class OptimizationValidationMixin:
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
has_varlen_attn = (
data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS
if data.get("attn_implementation")
else data.get("flash_attention")
)
if not has_varlen_attn and not batch_flattening_auto:
raise ValueError(
"batch_flattening requires a varlen-capable attention backend "
"(e.g., attn_implementation: flash)"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
@@ -966,7 +983,7 @@ class OptimizationValidationMixin:
if (
batch_flattening_auto
and data.get("flash_attention")
and has_varlen_attn
and not data.get("sample_packing")
and data.get("micro_batch_size") > 1
):
@@ -1211,6 +1228,12 @@ class SystemValidationMixin:
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
unsupported_npu_impls = {"flash", "sdpa", "s2"}
attn_impl = data.get("attn_implementation")
if attn_impl and attn_impl in unsupported_npu_impls:
raise NotImplementedError(
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
)
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
@@ -1519,9 +1542,10 @@ class ComplexValidationMixin:
if not self.context_parallel_size:
self.context_parallel_size = 1
elif self.context_parallel_size > 1:
if not self.flash_attention:
if not self.attn_uses_flash_lib:
raise ValueError(
"flash_attention: true must be set with context_parallel_size > 1"
"context_parallel_size > 1 requires flash attention "
"(attn_implementation: flash or s2)."
)
if self.sample_packing and self.micro_batch_size > 1:
@@ -1658,7 +1682,9 @@ class EBFTValidationMixin:
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and data.get("flex_attention")
and (
data.get("flex_attention") or data.get("attn_implementation") == "flex"
)
and data.get("gradient_checkpointing")
):
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}

View File

@@ -462,7 +462,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
)
else:
if cfg.flash_attention and not cfg.multipack_real_batches:
if cfg.attn_supports_packing and not cfg.multipack_real_batches:
sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:

View File

@@ -1,11 +1,17 @@
"""
Tests for attn_implementation normalization, registry registration, and
backwards compatibility with legacy boolean attention flags.
Tests for attn_implementation normalization, registry registration,
capability properties, and backwards compatibility with legacy boolean
attention flags.
"""
import pytest
from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS,
_NON_PACKING_ATTN_IMPLS,
FLASH_ATTN_LIB_IMPLS,
)
class TestAttnImplementationNormalizer:
@@ -18,22 +24,31 @@ class TestAttnImplementationNormalizer:
# --- Forward mapping: attn_implementation -> legacy flags ---
@pytest.mark.parametrize(
"impl,expected_flags",
"impl,expected_flag",
[
("eager", {"eager_attention": True}),
("flash", {"flash_attention": True}),
("sdpa", {"sdp_attention": True}),
("flex", {"flex_attention": True}),
("xformers", {"xformers_attention": True, "flash_attention": True}),
("sage", {"sage_attention": True, "flash_attention": True}),
("s2", {"s2_attention": True, "flash_attention": True}),
("eager", "eager_attention"),
("flash", "flash_attention"),
("sdpa", "sdp_attention"),
("flex", "flex_attention"),
("xformers", "xformers_attention"),
("sage", "sage_attention"),
("s2", "s2_attention"),
],
)
def test_attn_impl_sets_legacy_flags(self, impl, expected_flags):
def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
data = {"attn_implementation": impl}
result = AxolotlInputConfig.normalize_attn_implementation(data)
for flag, val in expected_flags.items():
assert result.get(flag) == val, f"{impl}: expected {flag}={val}"
assert result.get(expected_flag) is True, (
f"{impl}: expected {expected_flag}=True"
)
@pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
"""xformers, sage, s2 should NOT set flash_attention=True anymore."""
result = self._normalize({"attn_implementation": impl})
assert not result.get("flash_attention"), (
f"{impl} should not set flash_attention"
)
def test_fp8_sets_no_legacy_flags(self):
result = self._normalize({"attn_implementation": "fp8"})
@@ -87,27 +102,13 @@ class TestAttnImplementationNormalizer:
assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True
def test_consistent_xformers_with_extra_flags(self):
"""xformers needs flash_attention=True, so both flags with attn_impl should be OK."""
def test_consistent_xformers_with_own_flag(self):
"""xformers + xformers_attention should be OK."""
result = self._normalize(
{
"attn_implementation": "xformers",
"xformers_attention": True,
"flash_attention": True,
}
{"attn_implementation": "xformers", "xformers_attention": True}
)
assert result["attn_implementation"] == "xformers"
def test_consistent_s2_with_flash(self):
result = self._normalize(
{
"attn_implementation": "s2",
"s2_attention": True,
"flash_attention": True,
}
)
assert result["attn_implementation"] == "s2"
# --- Conflict detection ---
def test_conflicting_impl_and_flag_raises(self):
@@ -118,6 +119,28 @@ class TestAttnImplementationNormalizer:
with pytest.raises(ValueError, match="conflicts with"):
self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
def test_xformers_with_flash_flag_conflicts(self):
"""After normalizer change, xformers no longer expects flash_attention."""
with pytest.raises(ValueError, match="conflicts with"):
self._normalize(
{
"attn_implementation": "xformers",
"xformers_attention": True,
"flash_attention": True,
}
)
def test_s2_with_flash_flag_conflicts(self):
"""After normalizer change, s2 no longer expects flash_attention."""
with pytest.raises(ValueError, match="conflicts with"):
self._normalize(
{
"attn_implementation": "s2",
"s2_attention": True,
"flash_attention": True,
}
)
# --- Hub kernel strings pass through ---
def test_hub_kernel_passthrough(self):
@@ -144,16 +167,69 @@ class TestAttnImplementationNormalizer:
result = self._normalize({"some_other_config": True})
assert result.get("attn_implementation") is None
# --- Sample packing interactions ---
# --- Gemma4 hybrid ---
def test_xformers_with_sample_packing_sets_flash(self):
"""xformers + sample_packing needs flash_attention=True for the patch chain."""
result = self._normalize(
{"attn_implementation": "xformers", "sample_packing": True}
)
assert result["xformers_attention"] is True
def test_gemma4_hybrid_sets_flash(self):
"""gemma4_hybrid_attn_impl should default attn_implementation to flash."""
result = self._normalize({"gemma4_hybrid_attn_impl": True})
assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True
def test_gemma4_hybrid_does_not_override_explicit(self):
"""If attn_implementation is already set, gemma4 should not override it."""
result = self._normalize(
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
)
assert result["attn_implementation"] == "sdpa"
class TestAttnCapabilityProperties:
"""Test the capability properties on the normalizer data.
Since these are @property on AxolotlInputConfig (a Pydantic model),
we test the underlying logic directly using the constant sets.
"""
# --- attn_supports_packing ---
@pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"])
def test_supports_packing_true(self, impl):
assert impl not in _NON_PACKING_ATTN_IMPLS
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
def test_supports_packing_false(self, impl):
assert impl in _NON_PACKING_ATTN_IMPLS
def test_hub_kernel_supports_packing(self):
"""Unknown hub kernels should default to packing-capable."""
assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS
# --- attn_uses_flash_lib ---
@pytest.mark.parametrize("impl", ["flash", "s2"])
def test_uses_flash_lib_true(self, impl):
assert impl in FLASH_ATTN_LIB_IMPLS
@pytest.mark.parametrize(
"impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"]
)
def test_uses_flash_lib_false(self, impl):
assert impl not in FLASH_ATTN_LIB_IMPLS
def test_hub_kernel_not_flash_lib(self):
"""Hub kernels are HF-managed, not axolotl monkeypatch targets."""
assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS
# --- attn_needs_dtype_cast ---
@pytest.mark.parametrize("impl", ["eager", "sdpa"])
def test_no_dtype_cast(self, impl):
assert impl in _NO_DTYPE_CAST_ATTN_IMPLS
@pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"])
def test_needs_dtype_cast(self, impl):
assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
class TestAttnImplToHFMapping:
"""Test that attn_implementation enum values map correctly to HF strings."""