replace legacy attention boolean flags with capability properties

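Replace checks of the legacy boolean attention flags (flash_attention, flex_attention, xformers_attention, sage_attention, ...) with capability-based properties derived from attn_implementation.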

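This separates three concerns that were previously conflated under the legacy flash_attention flag:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property
A related attn_needs_dtype_cast property replaces the flash/flex/sage flag checks that gated the fp16/bf16 dtype cast in the model loader.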
This commit is contained in:
Wing Lian
2026-04-12 22:01:09 -04:00
parent aee8c75d64
commit ff5d6393c8
13 changed files with 274 additions and 136 deletions

View File

@@ -147,7 +147,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False, load_in_8bit=False,
load_in_4bit=False, load_in_4bit=False,
quantize_moe_experts=False, quantize_moe_experts=False,
flash_attention=False, attn_implementation=None,
context_parallel_size=None, context_parallel_size=None,
deepspeed=None, deepspeed=None,
fsdp=None, fsdp=None,

View File

@@ -257,19 +257,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool( training_arguments_kwargs["sample_packing_drop_attention_mask"] = (
self.cfg.flash_attention self.cfg.attn_supports_packing
or self.cfg.xformers_attention
or self.cfg.flex_attention
) )
training_arguments_kwargs["multipack_real_batches"] = ( training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None if self.cfg.multipack_real_batches is not None
else not ( else not self.cfg.attn_supports_packing
self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.xformers_attention
)
) )
training_arguments_kwargs["eval_sample_packing"] = bool( training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing self.cfg.eval_sample_packing
@@ -508,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama # supported multipack models, or non-flash-attention llama
if ( if (
self.cfg.flex_attention self.cfg.attn_implementation == "flex"
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or ( or (
self.cfg.model_config_type in ["llama"] self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True and self.cfg.attn_implementation != "flash"
) )
): ):
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq

View File

@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
for lm_eval_args in build_lm_eval_command( for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks, cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16, bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention, flash_attention=(cfg.attn_implementation == "flash"),
output_dir=cfg.output_dir, output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size, batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,

View File

@@ -383,7 +383,9 @@ class SwanLabPlugin(BasePlugin):
"seed": safe_convert(getattr(cfg, "seed", None)), "seed": safe_convert(getattr(cfg, "seed", None)),
"bf16": safe_convert(getattr(cfg, "bf16", None)), "bf16": safe_convert(getattr(cfg, "bf16", None)),
"tf32": safe_convert(getattr(cfg, "tf32", None)), "tf32": safe_convert(getattr(cfg, "tf32", None)),
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)), "attn_implementation": safe_convert(
getattr(cfg, "attn_implementation", None)
),
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)), "sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
} }

View File

@@ -343,12 +343,7 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility. # we need to convert them back to fp16/bf16 for flash-attn compatibility.
( (
( (needs_fa2_dtype or self.cfg.attn_needs_dtype_cast)
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
and not self.is_qlora_and_fsdp_enabled and not self.is_qlora_and_fsdp_enabled
) )
or ( or (
@@ -656,32 +651,12 @@ class ModelLoader:
# global layers will be patched to sdpa post-load. # global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2" self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2" self.model_config._attn_implementation = "flash_attention_2"
# Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True
elif self.cfg.attn_implementation: elif self.cfg.attn_implementation:
hf_impl = _ATTN_IMPL_TO_HF.get( hf_impl = _ATTN_IMPL_TO_HF.get(
self.cfg.attn_implementation, self.cfg.attn_implementation self.cfg.attn_implementation, self.cfg.attn_implementation
) )
self.model_kwargs["attn_implementation"] = hf_impl self.model_kwargs["attn_implementation"] = hf_impl
self.model_config._attn_implementation = hf_impl self.model_config._attn_implementation = hf_impl
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = "flex_attention"
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager"
if self.cfg.low_cpu_mem_usage: if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True self.model_kwargs["low_cpu_mem_usage"] = True

View File

@@ -253,7 +253,7 @@ class PatchManager:
def _apply_flash_attention_patches(self): def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention.""" """Apply patches related to Flash Attention."""
if self.cfg.xformers_attention: if self.cfg.attn_implementation == "xformers":
from axolotl.monkeypatch.attention import register_xformers_attn from axolotl.monkeypatch.attention import register_xformers_attn
register_xformers_attn() register_xformers_attn()
@@ -263,9 +263,8 @@ class PatchManager:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2() patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.sage_attention: if self.cfg.attn_implementation == "sage":
from axolotl.monkeypatch.attention import register_sage_attn from axolotl.monkeypatch.attention import register_sage_attn
register_sage_attn() register_sage_attn()
@@ -334,7 +333,7 @@ class PatchManager:
def _apply_flex_attention_patches(self): def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention.""" """Apply patches for flexible attention."""
if self.cfg.flex_attention: if self.cfg.attn_implementation == "flex":
from axolotl.monkeypatch.attention.flex_attn import ( from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper, patch_flex_wrapper,
) )
@@ -344,14 +343,14 @@ class PatchManager:
def _apply_sageattn_patches(self): def _apply_sageattn_patches(self):
"""Apply patches for SageAttention.""" """Apply patches for SageAttention."""
if self.cfg.sage_attention: if self.cfg.attn_implementation == "sage":
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn() patch_sageattn()
def _apply_flash_attn_4_patches(self): def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+.""" """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention: if not self.cfg.attn_uses_flash_lib:
return return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4 from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
@@ -420,7 +419,7 @@ class PatchManager:
if ( if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"] self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal and self.cfg.is_multimodal
and self.cfg.flash_attention and self.cfg.attn_uses_flash_lib
): ):
from axolotl.monkeypatch.models.qwen3_5.modeling import ( from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention, patch_qwen3_5_vlm_flash_attention,
@@ -572,7 +571,7 @@ class PatchManager:
"""Apply multipack patches if necessary.""" """Apply multipack patches if necessary."""
if ( if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention) and self.cfg.attn_supports_packing
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
# Get automap config if it exists # Get automap config if it exists
@@ -693,7 +692,9 @@ class PatchManager:
def _patch_attention(self): def _patch_attention(self):
"""Apply attention-specific patches based on model type.""" """Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): if not (
self.cfg.attn_uses_flash_lib and hasattr(self.model_config, "model_type")
):
return return
if self.model_config.model_type == "btlm": if self.model_config.model_type == "btlm":
@@ -739,7 +740,7 @@ class PatchManager:
replace_llama_attn_with_flash_attn, replace_llama_attn_with_flash_attn,
) )
if self.cfg.s2_attention: if self.cfg.attn_implementation == "s2":
LOG.info("patching w/ flash-enabled, shifted-sparse attention") LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn( replace_llama_attn_with_flash_attn(
cross_entropy=self.cfg.flash_attn_cross_entropy, cross_entropy=self.cfg.flash_attn_cross_entropy,
@@ -765,14 +766,14 @@ class PatchManager:
"""Modify all llama derived models in one block.""" """Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not ( if self.cfg.is_llama_derived_model and not (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention) and self.cfg.attn_supports_packing
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
if self.cfg.flash_attention: if self.cfg.attn_uses_flash_lib:
self._patch_llama_flash_attention() self._patch_llama_flash_attention()
elif self.cfg.xformers_attention: elif self.cfg.attn_implementation == "xformers":
self._patch_llama_xformers_attention() self._patch_llama_xformers_attention()
elif self.cfg.s2_attention: elif self.cfg.attn_implementation == "s2":
raise NotImplementedError( raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention." "Shifted-sparse attention not currently implemented without flash attention."
) )
@@ -784,7 +785,7 @@ class PatchManager:
in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"] in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
and not self.cfg.gptq and not self.cfg.gptq
and self.cfg.flash_attention and self.cfg.attn_uses_flash_lib
and is_flash_attn_available() and is_flash_attn_available()
and not self.inference and not self.inference
): ):

View File

@@ -205,7 +205,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding # Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: if (
cfg.is_mistral_derived_model
and cfg.attn_implementation == "flash"
and not cfg.sample_packing
):
tokenizer.padding_side = "left" tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens # Qwen base only has single token, so we need to set the special tokens

View File

@@ -955,9 +955,9 @@ def colab_inference_post_train_callback(trainer: Trainer):
""" """
handle T4 gpu, we need to convert attention to eager for inference handle T4 gpu, we need to convert attention to eager for inference
""" """
if "Tesla T4" in self.gpu_name and ( if (
self.cfg.xformers_attention "Tesla T4" in self.gpu_name
or self.cfg.attn_implementation == "xformers" and self.cfg.attn_implementation == "xformers"
): ):
trainer.model.config._attn_implementation = "eager" trainer.model.config._attn_implementation = "eager"
trainer.model.gradient_checkpointing_disable() trainer.model.gradient_checkpointing_disable()

View File

@@ -28,6 +28,9 @@ from axolotl.utils.schemas.datasets import (
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import ( from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS,
_NON_PACKING_ATTN_IMPLS,
FLASH_ATTN_LIB_IMPLS,
AttnImplementation, AttnImplementation,
ChatTemplate, ChatTemplate,
RingAttnFunc, RingAttnFunc,
@@ -1332,6 +1335,40 @@ class AxolotlInputConfig(
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None return None
# --- Attention capability properties ---
@property
def attn_supports_packing(self) -> bool:
"""True if attention supports varlen sample packing via position_ids.
Known varlen backends: flash, flex, xformers, sage.
Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
default to True since they generally support varlen.
"""
if not self.attn_implementation:
return False
return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
@property
def attn_uses_flash_lib(self) -> bool:
"""True if the backend uses axolotl's flash_attn monkeypatches.
Only for axolotl-managed FA setup (flash, s2). Hub kernels are
HF-managed and don't need these patches.
"""
return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
@property
def attn_needs_dtype_cast(self) -> bool:
"""True if attention needs embedding dtype cast to fp16/bf16.
Unknown backends (hub kernels) default to True (safe -- harmless
if unnecessary, but missing cast causes errors).
"""
if not self.attn_implementation:
return False
return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def warn_peft_trainable_token_to_fix_untrained(cls, data): def warn_peft_trainable_token_to_fix_untrained(cls, data):
@@ -1358,16 +1395,22 @@ class AxolotlInputConfig(
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags.""" """Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
attn_impl = data.get("attn_implementation") attn_impl = data.get("attn_implementation")
# Mapping: attn_implementation value -> (primary flag, extra flags to set) # If gemma4_hybrid_attn_impl is set but no attn_implementation, default
impl_to_flags = { # to flash (the sliding-window layers use FA2, and packing should be enabled).
"eager": (("eager_attention",), ()), if data.get("gemma4_hybrid_attn_impl") and not attn_impl:
"flash": (("flash_attention",), ()), data["attn_implementation"] = "flash"
"sdpa": (("sdp_attention",), ()), attn_impl = "flash"
"xformers": (("xformers_attention",), ("flash_attention",)),
"flex": (("flex_attention",), ()), # Mapping: attn_implementation value -> primary legacy flag to set
"sage": (("sage_attention",), ("flash_attention",)), impl_to_flag = {
"s2": (("s2_attention",), ("flash_attention",)), "eager": "eager_attention",
"fp8": ((), ()), # new, no legacy flags "flash": "flash_attention",
"sdpa": "sdp_attention",
"xformers": "xformers_attention",
"flex": "flex_attention",
"sage": "sage_attention",
"s2": "s2_attention",
"fp8": None, # new, no legacy flag
} }
# Reverse mapping: legacy flag -> attn_implementation value # Reverse mapping: legacy flag -> attn_implementation value
@@ -1386,26 +1429,21 @@ class AxolotlInputConfig(
if attn_impl and set_flags: if attn_impl and set_flags:
# Both set — check consistency # Both set — check consistency
if attn_impl in impl_to_flags: expected_flag = impl_to_flag.get(attn_impl)
expected_primary, expected_extra = impl_to_flags[attn_impl] for flag in set_flags:
expected_flags = set(expected_primary) | set(expected_extra) if flag != expected_flag:
for flag in set_flags: raise ValueError(
if flag not in expected_flags: f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
raise ValueError( f"Use only attn_implementation or the legacy flag, not both."
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. " )
f"Use only attn_implementation or the legacy flag, not both."
)
elif attn_impl and not set_flags: elif attn_impl and not set_flags:
# attn_implementation set, no legacy flags — set them for backwards compat # attn_implementation set, no legacy flags — set primary for backwards compat
if attn_impl in impl_to_flags: flag = impl_to_flag.get(attn_impl)
primary, extra = impl_to_flags[attn_impl] if flag:
for flag in (*primary, *extra): data[flag] = True
data[flag] = True
elif not attn_impl and set_flags: elif not attn_impl and set_flags:
# Legacy flags set, no attn_implementation — map to enum, warn # Legacy flags set, no attn_implementation — map to enum, warn
# Priority: specific backends first, then generic flash/sdp/eager # Priority: specific backends first, then generic flash/sdp/eager
# s2 and sage require flash_attention internally, so they must be
# checked before flash_attention to avoid masking
priority = [ priority = [
"xformers_attention", "xformers_attention",
"s2_attention", "s2_attention",
@@ -1430,7 +1468,10 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sageattn_wo_sample_packing(cls, data): def check_sageattn_wo_sample_packing(cls, data):
if (not data.get("sample_packing", False)) and data.get("sage_attention"): is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("sample_packing", False)) and is_sage:
if not data.get("pad_to_sequence_len", False): if not data.get("pad_to_sequence_len", False):
LOG.warning( LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing." "We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
@@ -1441,7 +1482,10 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sageattn_fft(cls, data): def check_sageattn_fft(cls, data):
if (not data.get("adapter", False)) and data.get("sage_attention"): is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("adapter", False)) and is_sage:
LOG.warning( LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning." "We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method." "Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
@@ -1531,7 +1575,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
) )
if ( if (
data.get("sample_packing") data.get("sample_packing")
and data.get("sdp_attention") and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
and (data.get("bfloat16") or data.get("bf16")) and (data.get("bfloat16") or data.get("bf16"))
and not is_sm_90 and not is_sm_90
): ):
@@ -1546,8 +1590,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_compute_capability_w_sageattn(cls, data): def check_compute_capability_w_sageattn(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if ( if (
data.get("sage_attention") is_sage
and data.get("capabilities") and data.get("capabilities")
and data.get("capabilities").get("compute_capability") and data.get("capabilities").get("compute_capability")
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"] not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
@@ -1715,7 +1762,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_flex_torch_version(cls, data): def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (data.get("flex_attention")): if data.get("flex_attention") or data.get("attn_implementation") == "flex":
env_capabilities = data.get("env_capabilities", {}) env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version") torch_version = env_capabilities.get("torch_version")

View File

@@ -110,6 +110,19 @@ class AttnImplementation(str, Enum):
fp8 = "fp8" # pylint: disable=invalid-name fp8 = "fp8" # pylint: disable=invalid-name
# Backends that require the flash_attn library (Dao-AILab/flash-attention)
# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.)
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
# Known backends that do NOT support varlen sample packing via position_ids.
# Used as an exclusion list: unknown strings (e.g., HF hub kernels like
# "kernels-community/flash-attn3") default to packing-capable.
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
# Known backends that do NOT need embedding dtype cast.
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})
class RingAttnFunc(str, Enum): class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations""" """Enum class for supported `ring-flash-attn` implementations"""

View File

@@ -12,7 +12,12 @@ from pydantic import (
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.enums import (
_NON_PACKING_ATTN_IMPLS,
ChatTemplate,
RingAttnFunc,
RLType,
)
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -182,6 +187,10 @@ class AttentionValidationMixin:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_attention_fields(cls, data): def check_attention_fields(cls, data):
# If attn_implementation is set, the enum handles mutual exclusivity.
# This validator catches legacy configs with multiple boolean flags.
if data.get("attn_implementation"):
return data
fields = ( fields = (
"xformers_attention", "xformers_attention",
"sdp_attention", "sdp_attention",
@@ -436,7 +445,7 @@ class TrainingValidationMixin:
not (self.bf16 or self.bfloat16) not (self.bf16 or self.bfloat16)
and (self.fp16 or self.float16) and (self.fp16 or self.float16)
and not self.adapter and not self.adapter
and not self.flash_attention and not self.attn_uses_flash_lib
and self.sample_packing and self.sample_packing
): ):
LOG.warning( LOG.warning(
@@ -946,8 +955,16 @@ class OptimizationValidationMixin:
def check_batch_flattening_fa(cls, data): def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"): if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto" batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto: has_varlen_attn = (
raise ValueError("batch_flattening requires flash attention") data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS
if data.get("attn_implementation")
else data.get("flash_attention")
)
if not has_varlen_attn and not batch_flattening_auto:
raise ValueError(
"batch_flattening requires a varlen-capable attention backend "
"(e.g., attn_implementation: flash)"
)
if data.get("sample_packing") and not batch_flattening_auto: if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing") raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto: if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
@@ -966,7 +983,7 @@ class OptimizationValidationMixin:
if ( if (
batch_flattening_auto batch_flattening_auto
and data.get("flash_attention") and has_varlen_attn
and not data.get("sample_packing") and not data.get("sample_packing")
and data.get("micro_batch_size") > 1 and data.get("micro_batch_size") > 1
): ):
@@ -1211,6 +1228,12 @@ class SystemValidationMixin:
def check_npu_config(cls, data): def check_npu_config(cls, data):
if is_torch_npu_available(): if is_torch_npu_available():
# check attention config # check attention config
unsupported_npu_impls = {"flash", "sdpa", "s2"}
attn_impl = data.get("attn_implementation")
if attn_impl and attn_impl in unsupported_npu_impls:
raise NotImplementedError(
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
)
attn_list = ["flash_attention", "sdp_attention", "s2_attention"] attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list: for attn in attn_list:
if data.get(attn): if data.get(attn):
@@ -1519,9 +1542,10 @@ class ComplexValidationMixin:
if not self.context_parallel_size: if not self.context_parallel_size:
self.context_parallel_size = 1 self.context_parallel_size = 1
elif self.context_parallel_size > 1: elif self.context_parallel_size > 1:
if not self.flash_attention: if not self.attn_uses_flash_lib:
raise ValueError( raise ValueError(
"flash_attention: true must be set with context_parallel_size > 1" "context_parallel_size > 1 requires flash attention "
"(attn_implementation: flash or s2)."
) )
if self.sample_packing and self.micro_batch_size > 1: if self.sample_packing and self.micro_batch_size > 1:
@@ -1658,7 +1682,9 @@ class EBFTValidationMixin:
if ( if (
data.get("rl") == "ebft" data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided" and data.get("ebft", {}).get("mode") == "strided"
and data.get("flex_attention") and (
data.get("flex_attention") or data.get("attn_implementation") == "flex"
)
and data.get("gradient_checkpointing") and data.get("gradient_checkpointing")
): ):
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {} gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}

View File

@@ -462,7 +462,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}" f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
) )
else: else:
if cfg.flash_attention and not cfg.multipack_real_batches: if cfg.attn_supports_packing and not cfg.multipack_real_batches:
sampler_batch_size = 1 sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else: else:

View File

@@ -1,11 +1,17 @@
""" """
Tests for attn_implementation normalization, registry registration, and Tests for attn_implementation normalization, registry registration,
backwards compatibility with legacy boolean attention flags. capability properties, and backwards compatibility with legacy boolean
attention flags.
""" """
import pytest import pytest
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.schemas.enums import (
_NO_DTYPE_CAST_ATTN_IMPLS,
_NON_PACKING_ATTN_IMPLS,
FLASH_ATTN_LIB_IMPLS,
)
class TestAttnImplementationNormalizer: class TestAttnImplementationNormalizer:
@@ -18,22 +24,31 @@ class TestAttnImplementationNormalizer:
# --- Forward mapping: attn_implementation -> legacy flags --- # --- Forward mapping: attn_implementation -> legacy flags ---
@pytest.mark.parametrize( @pytest.mark.parametrize(
"impl,expected_flags", "impl,expected_flag",
[ [
("eager", {"eager_attention": True}), ("eager", "eager_attention"),
("flash", {"flash_attention": True}), ("flash", "flash_attention"),
("sdpa", {"sdp_attention": True}), ("sdpa", "sdp_attention"),
("flex", {"flex_attention": True}), ("flex", "flex_attention"),
("xformers", {"xformers_attention": True, "flash_attention": True}), ("xformers", "xformers_attention"),
("sage", {"sage_attention": True, "flash_attention": True}), ("sage", "sage_attention"),
("s2", {"s2_attention": True, "flash_attention": True}), ("s2", "s2_attention"),
], ],
) )
def test_attn_impl_sets_legacy_flags(self, impl, expected_flags): def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
data = {"attn_implementation": impl} data = {"attn_implementation": impl}
result = AxolotlInputConfig.normalize_attn_implementation(data) result = AxolotlInputConfig.normalize_attn_implementation(data)
for flag, val in expected_flags.items(): assert result.get(expected_flag) is True, (
assert result.get(flag) == val, f"{impl}: expected {flag}={val}" f"{impl}: expected {expected_flag}=True"
)
@pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
"""xformers, sage, s2 should NOT set flash_attention=True anymore."""
result = self._normalize({"attn_implementation": impl})
assert not result.get("flash_attention"), (
f"{impl} should not set flash_attention"
)
def test_fp8_sets_no_legacy_flags(self): def test_fp8_sets_no_legacy_flags(self):
result = self._normalize({"attn_implementation": "fp8"}) result = self._normalize({"attn_implementation": "fp8"})
@@ -87,27 +102,13 @@ class TestAttnImplementationNormalizer:
assert result["attn_implementation"] == "flash" assert result["attn_implementation"] == "flash"
assert result["flash_attention"] is True assert result["flash_attention"] is True
def test_consistent_xformers_with_extra_flags(self): def test_consistent_xformers_with_own_flag(self):
"""xformers needs flash_attention=True, so both flags with attn_impl should be OK.""" """xformers + xformers_attention should be OK."""
result = self._normalize( result = self._normalize(
{ {"attn_implementation": "xformers", "xformers_attention": True}
"attn_implementation": "xformers",
"xformers_attention": True,
"flash_attention": True,
}
) )
assert result["attn_implementation"] == "xformers" assert result["attn_implementation"] == "xformers"
def test_consistent_s2_with_flash(self):
result = self._normalize(
{
"attn_implementation": "s2",
"s2_attention": True,
"flash_attention": True,
}
)
assert result["attn_implementation"] == "s2"
# --- Conflict detection --- # --- Conflict detection ---
def test_conflicting_impl_and_flag_raises(self): def test_conflicting_impl_and_flag_raises(self):
@@ -118,6 +119,28 @@ class TestAttnImplementationNormalizer:
with pytest.raises(ValueError, match="conflicts with"): with pytest.raises(ValueError, match="conflicts with"):
self._normalize({"attn_implementation": "xformers", "sdp_attention": True}) self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
def test_xformers_with_flash_flag_conflicts(self):
    """Expect a conflict error for xformers combined with ``flash_attention``.

    The config selects ``attn_implementation="xformers"`` and sets both
    ``xformers_attention`` and ``flash_attention``; normalizing it must raise
    a ``ValueError`` whose message matches "conflicts with".
    """
    conflicting_cfg = {
        "attn_implementation": "xformers",
        "xformers_attention": True,
        "flash_attention": True,
    }
    with pytest.raises(ValueError, match="conflicts with"):
        self._normalize(conflicting_cfg)
def test_s2_with_flash_flag_conflicts(self):
    """Expect a conflict error for s2 combined with ``flash_attention``.

    The config selects ``attn_implementation="s2"`` and sets both
    ``s2_attention`` and ``flash_attention``; normalizing it must raise a
    ``ValueError`` whose message matches "conflicts with".
    """
    conflicting_cfg = {
        "attn_implementation": "s2",
        "s2_attention": True,
        "flash_attention": True,
    }
    with pytest.raises(ValueError, match="conflicts with"):
        self._normalize(conflicting_cfg)
# --- Hub kernel strings pass through --- # --- Hub kernel strings pass through ---
def test_hub_kernel_passthrough(self): def test_hub_kernel_passthrough(self):
@@ -144,16 +167,69 @@ class TestAttnImplementationNormalizer:
result = self._normalize({"some_other_config": True}) result = self._normalize({"some_other_config": True})
assert result.get("attn_implementation") is None assert result.get("attn_implementation") is None
# --- Sample packing interactions --- # --- Gemma4 hybrid ---
def test_xformers_with_sample_packing_sets_flash(self): def test_gemma4_hybrid_sets_flash(self):
"""xformers + sample_packing needs flash_attention=True for the patch chain.""" """gemma4_hybrid_attn_impl should default attn_implementation to flash."""
result = self._normalize( result = self._normalize({"gemma4_hybrid_attn_impl": True})
{"attn_implementation": "xformers", "sample_packing": True} assert result["attn_implementation"] == "flash"
)
assert result["xformers_attention"] is True
assert result["flash_attention"] is True assert result["flash_attention"] is True
def test_gemma4_hybrid_does_not_override_explicit(self):
    """Check that an explicit ``attn_implementation`` is kept as given.

    With ``gemma4_hybrid_attn_impl`` enabled and ``attn_implementation``
    already set to "sdpa", the normalized config must still report "sdpa".
    """
    explicit_cfg = {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
    normalized = self._normalize(explicit_cfg)
    assert normalized["attn_implementation"] == "sdpa"
class TestAttnCapabilityProperties:
    """Membership checks for the attention capability constant sets.

    The capability properties are ``@property`` members of
    ``AxolotlInputConfig`` (a Pydantic model), so these tests do not build a
    config. They assert directly against the constant sets instead:
    ``_NON_PACKING_ATTN_IMPLS``, ``FLASH_ATTN_LIB_IMPLS`` and
    ``_NO_DTYPE_CAST_ATTN_IMPLS``.
    """

    # --- attn_supports_packing ---

    @pytest.mark.parametrize(
        "impl",
        ["flash", "flex", "xformers", "sage"],
    )
    def test_supports_packing_true(self, impl):
        """These implementations must be absent from the non-packing set."""
        assert impl not in _NON_PACKING_ATTN_IMPLS

    @pytest.mark.parametrize(
        "impl",
        ["eager", "sdpa", "s2", "fp8"],
    )
    def test_supports_packing_false(self, impl):
        """These implementations must be listed in the non-packing set."""
        assert impl in _NON_PACKING_ATTN_IMPLS

    def test_hub_kernel_supports_packing(self):
        """Unknown hub kernels should default to packing-capable."""
        hub_kernel = "kernels-community/flash-attn3"
        assert hub_kernel not in _NON_PACKING_ATTN_IMPLS

    # --- attn_uses_flash_lib ---

    @pytest.mark.parametrize(
        "impl",
        ["flash", "s2"],
    )
    def test_uses_flash_lib_true(self, impl):
        """These implementations must be listed in the flash-lib set."""
        assert impl in FLASH_ATTN_LIB_IMPLS

    @pytest.mark.parametrize(
        "impl",
        ["eager", "sdpa", "xformers", "flex", "sage", "fp8"],
    )
    def test_uses_flash_lib_false(self, impl):
        """These implementations must be absent from the flash-lib set."""
        assert impl not in FLASH_ATTN_LIB_IMPLS

    def test_hub_kernel_not_flash_lib(self):
        """Hub kernels are HF-managed, not axolotl monkeypatch targets."""
        hub_kernel = "kernels-community/flash-attn3"
        assert hub_kernel not in FLASH_ATTN_LIB_IMPLS

    # --- attn_needs_dtype_cast ---

    @pytest.mark.parametrize(
        "impl",
        ["eager", "sdpa"],
    )
    def test_no_dtype_cast(self, impl):
        """These implementations must be listed in the no-dtype-cast set."""
        assert impl in _NO_DTYPE_CAST_ATTN_IMPLS

    @pytest.mark.parametrize(
        "impl",
        ["flash", "flex", "sage", "xformers", "s2", "fp8"],
    )
    def test_needs_dtype_cast(self, impl):
        """These implementations must be absent from the no-dtype-cast set."""
        assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
class TestAttnImplToHFMapping: class TestAttnImplToHFMapping:
"""Test that attn_implementation enum values map correctly to HF strings.""" """Test that attn_implementation enum values map correctly to HF strings."""