replace legacy attention boolean flags with capability properties
Replace checks with capability-based properties derived from attn_implementation This separates three concerns that were conflated under flash_attention: 1. Backend selection -> attn_implementation enum 2. Packing capability -> attn_supports_packing property 3. Flash-attn library dependency -> attn_uses_flash_lib property
This commit is contained in:
@@ -147,7 +147,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
quantize_moe_experts=False,
|
||||
flash_attention=False,
|
||||
attn_implementation=None,
|
||||
context_parallel_size=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
|
||||
@@ -257,19 +257,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
|
||||
self.cfg.flash_attention
|
||||
or self.cfg.xformers_attention
|
||||
or self.cfg.flex_attention
|
||||
training_arguments_kwargs["sample_packing_drop_attention_mask"] = (
|
||||
self.cfg.attn_supports_packing
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.multipack_real_batches
|
||||
if self.cfg.multipack_real_batches is not None
|
||||
else not (
|
||||
self.cfg.flash_attention
|
||||
or self.cfg.flex_attention
|
||||
or self.cfg.xformers_attention
|
||||
)
|
||||
else not self.cfg.attn_supports_packing
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||
self.cfg.eval_sample_packing
|
||||
@@ -508,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||
# supported multipack models, or non-flash-attention llama
|
||||
if (
|
||||
self.cfg.flex_attention
|
||||
self.cfg.attn_implementation == "flex"
|
||||
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
or (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
and self.cfg.attn_implementation != "flash"
|
||||
)
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
|
||||
@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||
flash_attention=cfg.flash_attention,
|
||||
flash_attention=(cfg.attn_implementation == "flash"),
|
||||
output_dir=cfg.output_dir,
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
|
||||
@@ -383,7 +383,9 @@ class SwanLabPlugin(BasePlugin):
|
||||
"seed": safe_convert(getattr(cfg, "seed", None)),
|
||||
"bf16": safe_convert(getattr(cfg, "bf16", None)),
|
||||
"tf32": safe_convert(getattr(cfg, "tf32", None)),
|
||||
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
|
||||
"attn_implementation": safe_convert(
|
||||
getattr(cfg, "attn_implementation", None)
|
||||
),
|
||||
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
|
||||
}
|
||||
|
||||
|
||||
@@ -343,12 +343,7 @@ class ModelLoader:
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
|
||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
(
|
||||
(
|
||||
needs_fa2_dtype
|
||||
or self.cfg.flash_attention
|
||||
or self.cfg.flex_attention
|
||||
or self.cfg.sage_attention
|
||||
)
|
||||
(needs_fa2_dtype or self.cfg.attn_needs_dtype_cast)
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
)
|
||||
or (
|
||||
@@ -656,32 +651,12 @@ class ModelLoader:
|
||||
# global layers will be patched to sdpa post-load.
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = "flash_attention_2"
|
||||
# Set flash_attention so multipack/sample_packing patches activate
|
||||
self.cfg.flash_attention = True
|
||||
elif self.cfg.attn_implementation:
|
||||
hf_impl = _ATTN_IMPL_TO_HF.get(
|
||||
self.cfg.attn_implementation, self.cfg.attn_implementation
|
||||
)
|
||||
self.model_kwargs["attn_implementation"] = hf_impl
|
||||
self.model_config._attn_implementation = hf_impl
|
||||
elif self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
self.model_config._attn_implementation = "flex_attention"
|
||||
elif self.cfg.flash_attention:
|
||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
pass
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = "flash_attention_2"
|
||||
elif self.cfg.sdp_attention:
|
||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
||||
self.model_config._attn_implementation = "sdpa"
|
||||
elif self.cfg.sage_attention:
|
||||
# sets FA2 attention to re-use same internal handling like masking
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = "flash_attention_2"
|
||||
elif self.cfg.eager_attention:
|
||||
self.model_kwargs["attn_implementation"] = "eager"
|
||||
self.model_config._attn_implementation = "eager"
|
||||
|
||||
if self.cfg.low_cpu_mem_usage:
|
||||
self.model_kwargs["low_cpu_mem_usage"] = True
|
||||
|
||||
@@ -253,7 +253,7 @@ class PatchManager:
|
||||
|
||||
def _apply_flash_attention_patches(self):
|
||||
"""Apply patches related to Flash Attention."""
|
||||
if self.cfg.xformers_attention:
|
||||
if self.cfg.attn_implementation == "xformers":
|
||||
from axolotl.monkeypatch.attention import register_xformers_attn
|
||||
|
||||
register_xformers_attn()
|
||||
@@ -263,9 +263,8 @@ class PatchManager:
|
||||
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
||||
|
||||
patch_xformers_attn_over_fa2()
|
||||
self.cfg.flash_attention = True
|
||||
|
||||
if self.cfg.sage_attention:
|
||||
if self.cfg.attn_implementation == "sage":
|
||||
from axolotl.monkeypatch.attention import register_sage_attn
|
||||
|
||||
register_sage_attn()
|
||||
@@ -334,7 +333,7 @@ class PatchManager:
|
||||
|
||||
def _apply_flex_attention_patches(self):
|
||||
"""Apply patches for flexible attention."""
|
||||
if self.cfg.flex_attention:
|
||||
if self.cfg.attn_implementation == "flex":
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_wrapper,
|
||||
)
|
||||
@@ -344,14 +343,14 @@ class PatchManager:
|
||||
|
||||
def _apply_sageattn_patches(self):
|
||||
"""Apply patches for SageAttention."""
|
||||
if self.cfg.sage_attention:
|
||||
if self.cfg.attn_implementation == "sage":
|
||||
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
|
||||
|
||||
patch_sageattn()
|
||||
|
||||
def _apply_flash_attn_4_patches(self):
|
||||
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
|
||||
if not self.cfg.flash_attention:
|
||||
if not self.cfg.attn_uses_flash_lib:
|
||||
return
|
||||
|
||||
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
|
||||
@@ -420,7 +419,7 @@ class PatchManager:
|
||||
if (
|
||||
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
|
||||
and self.cfg.is_multimodal
|
||||
and self.cfg.flash_attention
|
||||
and self.cfg.attn_uses_flash_lib
|
||||
):
|
||||
from axolotl.monkeypatch.models.qwen3_5.modeling import (
|
||||
patch_qwen3_5_vlm_flash_attention,
|
||||
@@ -572,7 +571,7 @@ class PatchManager:
|
||||
"""Apply multipack patches if necessary."""
|
||||
if (
|
||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
||||
and self.cfg.attn_supports_packing
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
# Get automap config if it exists
|
||||
@@ -693,7 +692,9 @@ class PatchManager:
|
||||
|
||||
def _patch_attention(self):
|
||||
"""Apply attention-specific patches based on model type."""
|
||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||
if not (
|
||||
self.cfg.attn_uses_flash_lib and hasattr(self.model_config, "model_type")
|
||||
):
|
||||
return
|
||||
|
||||
if self.model_config.model_type == "btlm":
|
||||
@@ -739,7 +740,7 @@ class PatchManager:
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
if self.cfg.s2_attention:
|
||||
if self.cfg.attn_implementation == "s2":
|
||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
@@ -765,14 +766,14 @@ class PatchManager:
|
||||
"""Modify all llama derived models in one block."""
|
||||
if self.cfg.is_llama_derived_model and not (
|
||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
||||
and self.cfg.attn_supports_packing
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
if self.cfg.flash_attention:
|
||||
if self.cfg.attn_uses_flash_lib:
|
||||
self._patch_llama_flash_attention()
|
||||
elif self.cfg.xformers_attention:
|
||||
elif self.cfg.attn_implementation == "xformers":
|
||||
self._patch_llama_xformers_attention()
|
||||
elif self.cfg.s2_attention:
|
||||
elif self.cfg.attn_implementation == "s2":
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
@@ -784,7 +785,7 @@ class PatchManager:
|
||||
in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
|
||||
and not self.cfg.trust_remote_code
|
||||
and not self.cfg.gptq
|
||||
and self.cfg.flash_attention
|
||||
and self.cfg.attn_uses_flash_lib
|
||||
and is_flash_attn_available()
|
||||
and not self.inference
|
||||
):
|
||||
|
||||
@@ -205,7 +205,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Mistral's official FA implementation requires left padding
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_mistral_derived_model
|
||||
and cfg.attn_implementation == "flash"
|
||||
and not cfg.sample_packing
|
||||
):
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Qwen base only has single token, so we need to set the special tokens
|
||||
|
||||
@@ -955,9 +955,9 @@ def colab_inference_post_train_callback(trainer: Trainer):
|
||||
"""
|
||||
handle T4 gpu, we need to convert attention to eager for inference
|
||||
"""
|
||||
if "Tesla T4" in self.gpu_name and (
|
||||
self.cfg.xformers_attention
|
||||
or self.cfg.attn_implementation == "xformers"
|
||||
if (
|
||||
"Tesla T4" in self.gpu_name
|
||||
and self.cfg.attn_implementation == "xformers"
|
||||
):
|
||||
trainer.model.config._attn_implementation = "eager"
|
||||
trainer.model.gradient_checkpointing_disable()
|
||||
|
||||
@@ -28,6 +28,9 @@ from axolotl.utils.schemas.datasets import (
|
||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||
from axolotl.utils.schemas.enums import (
|
||||
_NO_DTYPE_CAST_ATTN_IMPLS,
|
||||
_NON_PACKING_ATTN_IMPLS,
|
||||
FLASH_ATTN_LIB_IMPLS,
|
||||
AttnImplementation,
|
||||
ChatTemplate,
|
||||
RingAttnFunc,
|
||||
@@ -1332,6 +1335,40 @@ class AxolotlInputConfig(
|
||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||
return None
|
||||
|
||||
# --- Attention capability properties ---
|
||||
|
||||
@property
|
||||
def attn_supports_packing(self) -> bool:
|
||||
"""True if attention supports varlen sample packing via position_ids.
|
||||
|
||||
Known varlen backends: flash, flex, xformers, sage.
|
||||
Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
|
||||
default to True since they generally support varlen.
|
||||
"""
|
||||
if not self.attn_implementation:
|
||||
return False
|
||||
return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
|
||||
|
||||
@property
|
||||
def attn_uses_flash_lib(self) -> bool:
|
||||
"""True if the backend uses axolotl's flash_attn monkeypatches.
|
||||
|
||||
Only for axolotl-managed FA setup (flash, s2). Hub kernels are
|
||||
HF-managed and don't need these patches.
|
||||
"""
|
||||
return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
|
||||
|
||||
@property
|
||||
def attn_needs_dtype_cast(self) -> bool:
|
||||
"""True if attention needs embedding dtype cast to fp16/bf16.
|
||||
|
||||
Unknown backends (hub kernels) default to True (safe -- harmless
|
||||
if unnecessary, but missing cast causes errors).
|
||||
"""
|
||||
if not self.attn_implementation:
|
||||
return False
|
||||
return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_peft_trainable_token_to_fix_untrained(cls, data):
|
||||
@@ -1358,16 +1395,22 @@ class AxolotlInputConfig(
|
||||
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
|
||||
attn_impl = data.get("attn_implementation")
|
||||
|
||||
# Mapping: attn_implementation value -> (primary flag, extra flags to set)
|
||||
impl_to_flags = {
|
||||
"eager": (("eager_attention",), ()),
|
||||
"flash": (("flash_attention",), ()),
|
||||
"sdpa": (("sdp_attention",), ()),
|
||||
"xformers": (("xformers_attention",), ("flash_attention",)),
|
||||
"flex": (("flex_attention",), ()),
|
||||
"sage": (("sage_attention",), ("flash_attention",)),
|
||||
"s2": (("s2_attention",), ("flash_attention",)),
|
||||
"fp8": ((), ()), # new, no legacy flags
|
||||
# If gemma4_hybrid_attn_impl is set but no attn_implementation, default
|
||||
# to flash (the sliding-window layers use FA2, and packing should be enabled).
|
||||
if data.get("gemma4_hybrid_attn_impl") and not attn_impl:
|
||||
data["attn_implementation"] = "flash"
|
||||
attn_impl = "flash"
|
||||
|
||||
# Mapping: attn_implementation value -> primary legacy flag to set
|
||||
impl_to_flag = {
|
||||
"eager": "eager_attention",
|
||||
"flash": "flash_attention",
|
||||
"sdpa": "sdp_attention",
|
||||
"xformers": "xformers_attention",
|
||||
"flex": "flex_attention",
|
||||
"sage": "sage_attention",
|
||||
"s2": "s2_attention",
|
||||
"fp8": None, # new, no legacy flag
|
||||
}
|
||||
|
||||
# Reverse mapping: legacy flag -> attn_implementation value
|
||||
@@ -1386,26 +1429,21 @@ class AxolotlInputConfig(
|
||||
|
||||
if attn_impl and set_flags:
|
||||
# Both set — check consistency
|
||||
if attn_impl in impl_to_flags:
|
||||
expected_primary, expected_extra = impl_to_flags[attn_impl]
|
||||
expected_flags = set(expected_primary) | set(expected_extra)
|
||||
for flag in set_flags:
|
||||
if flag not in expected_flags:
|
||||
raise ValueError(
|
||||
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
|
||||
f"Use only attn_implementation or the legacy flag, not both."
|
||||
)
|
||||
expected_flag = impl_to_flag.get(attn_impl)
|
||||
for flag in set_flags:
|
||||
if flag != expected_flag:
|
||||
raise ValueError(
|
||||
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
|
||||
f"Use only attn_implementation or the legacy flag, not both."
|
||||
)
|
||||
elif attn_impl and not set_flags:
|
||||
# attn_implementation set, no legacy flags — set them for backwards compat
|
||||
if attn_impl in impl_to_flags:
|
||||
primary, extra = impl_to_flags[attn_impl]
|
||||
for flag in (*primary, *extra):
|
||||
data[flag] = True
|
||||
# attn_implementation set, no legacy flags — set primary for backwards compat
|
||||
flag = impl_to_flag.get(attn_impl)
|
||||
if flag:
|
||||
data[flag] = True
|
||||
elif not attn_impl and set_flags:
|
||||
# Legacy flags set, no attn_implementation — map to enum, warn
|
||||
# Priority: specific backends first, then generic flash/sdp/eager
|
||||
# s2 and sage require flash_attention internally, so they must be
|
||||
# checked before flash_attention to avoid masking
|
||||
priority = [
|
||||
"xformers_attention",
|
||||
"s2_attention",
|
||||
@@ -1430,7 +1468,10 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sageattn_wo_sample_packing(cls, data):
|
||||
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
|
||||
is_sage = (
|
||||
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
||||
)
|
||||
if (not data.get("sample_packing", False)) and is_sage:
|
||||
if not data.get("pad_to_sequence_len", False):
|
||||
LOG.warning(
|
||||
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
|
||||
@@ -1441,7 +1482,10 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sageattn_fft(cls, data):
|
||||
if (not data.get("adapter", False)) and data.get("sage_attention"):
|
||||
is_sage = (
|
||||
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
||||
)
|
||||
if (not data.get("adapter", False)) and is_sage:
|
||||
LOG.warning(
|
||||
"We found loss to drop to 0 with SageAttention full finetuning."
|
||||
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
|
||||
@@ -1531,7 +1575,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
if (
|
||||
data.get("sample_packing")
|
||||
and data.get("sdp_attention")
|
||||
and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
|
||||
and (data.get("bfloat16") or data.get("bf16"))
|
||||
and not is_sm_90
|
||||
):
|
||||
@@ -1546,8 +1590,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_compute_capability_w_sageattn(cls, data):
|
||||
is_sage = (
|
||||
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
||||
)
|
||||
if (
|
||||
data.get("sage_attention")
|
||||
is_sage
|
||||
and data.get("capabilities")
|
||||
and data.get("capabilities").get("compute_capability")
|
||||
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
||||
@@ -1715,7 +1762,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_flex_torch_version(cls, data):
|
||||
if (data.get("flex_attention") is not None) and (data.get("flex_attention")):
|
||||
if data.get("flex_attention") or data.get("attn_implementation") == "flex":
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
|
||||
|
||||
@@ -110,6 +110,19 @@ class AttnImplementation(str, Enum):
|
||||
fp8 = "fp8" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Backends that require the flash_attn library (Dao-AILab/flash-attention)
|
||||
# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.)
|
||||
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
|
||||
|
||||
# Known backends that do NOT support varlen sample packing via position_ids.
|
||||
# Used as an exclusion list: unknown strings (e.g., HF hub kernels like
|
||||
# "kernels-community/flash-attn3") default to packing-capable.
|
||||
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
|
||||
|
||||
# Known backends that do NOT need embedding dtype cast.
|
||||
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
|
||||
@@ -12,7 +12,12 @@ from pydantic import (
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||
from axolotl.utils.schemas.enums import (
|
||||
_NON_PACKING_ATTN_IMPLS,
|
||||
ChatTemplate,
|
||||
RingAttnFunc,
|
||||
RLType,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -182,6 +187,10 @@ class AttentionValidationMixin:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_attention_fields(cls, data):
|
||||
# If attn_implementation is set, the enum handles mutual exclusivity.
|
||||
# This validator catches legacy configs with multiple boolean flags.
|
||||
if data.get("attn_implementation"):
|
||||
return data
|
||||
fields = (
|
||||
"xformers_attention",
|
||||
"sdp_attention",
|
||||
@@ -436,7 +445,7 @@ class TrainingValidationMixin:
|
||||
not (self.bf16 or self.bfloat16)
|
||||
and (self.fp16 or self.float16)
|
||||
and not self.adapter
|
||||
and not self.flash_attention
|
||||
and not self.attn_uses_flash_lib
|
||||
and self.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
@@ -946,8 +955,16 @@ class OptimizationValidationMixin:
|
||||
def check_batch_flattening_fa(cls, data):
|
||||
if data.get("batch_flattening"):
|
||||
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
||||
if not data.get("flash_attention") and not batch_flattening_auto:
|
||||
raise ValueError("batch_flattening requires flash attention")
|
||||
has_varlen_attn = (
|
||||
data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS
|
||||
if data.get("attn_implementation")
|
||||
else data.get("flash_attention")
|
||||
)
|
||||
if not has_varlen_attn and not batch_flattening_auto:
|
||||
raise ValueError(
|
||||
"batch_flattening requires a varlen-capable attention backend "
|
||||
"(e.g., attn_implementation: flash)"
|
||||
)
|
||||
if data.get("sample_packing") and not batch_flattening_auto:
|
||||
raise ValueError("batch_flattening not compatible with sample_packing")
|
||||
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
||||
@@ -966,7 +983,7 @@ class OptimizationValidationMixin:
|
||||
|
||||
if (
|
||||
batch_flattening_auto
|
||||
and data.get("flash_attention")
|
||||
and has_varlen_attn
|
||||
and not data.get("sample_packing")
|
||||
and data.get("micro_batch_size") > 1
|
||||
):
|
||||
@@ -1211,6 +1228,12 @@ class SystemValidationMixin:
|
||||
def check_npu_config(cls, data):
|
||||
if is_torch_npu_available():
|
||||
# check attention config
|
||||
unsupported_npu_impls = {"flash", "sdpa", "s2"}
|
||||
attn_impl = data.get("attn_implementation")
|
||||
if attn_impl and attn_impl in unsupported_npu_impls:
|
||||
raise NotImplementedError(
|
||||
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
|
||||
)
|
||||
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
||||
for attn in attn_list:
|
||||
if data.get(attn):
|
||||
@@ -1519,9 +1542,10 @@ class ComplexValidationMixin:
|
||||
if not self.context_parallel_size:
|
||||
self.context_parallel_size = 1
|
||||
elif self.context_parallel_size > 1:
|
||||
if not self.flash_attention:
|
||||
if not self.attn_uses_flash_lib:
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with context_parallel_size > 1"
|
||||
"context_parallel_size > 1 requires flash attention "
|
||||
"(attn_implementation: flash or s2)."
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
@@ -1658,7 +1682,9 @@ class EBFTValidationMixin:
|
||||
if (
|
||||
data.get("rl") == "ebft"
|
||||
and data.get("ebft", {}).get("mode") == "strided"
|
||||
and data.get("flex_attention")
|
||||
and (
|
||||
data.get("flex_attention") or data.get("attn_implementation") == "flex"
|
||||
)
|
||||
and data.get("gradient_checkpointing")
|
||||
):
|
||||
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
|
||||
|
||||
@@ -462,7 +462,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attention and not cfg.multipack_real_batches:
|
||||
if cfg.attn_supports_packing and not cfg.multipack_real_batches:
|
||||
sampler_batch_size = 1
|
||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||
else:
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
"""
|
||||
Tests for attn_implementation normalization, registry registration, and
|
||||
backwards compatibility with legacy boolean attention flags.
|
||||
Tests for attn_implementation normalization, registry registration,
|
||||
capability properties, and backwards compatibility with legacy boolean
|
||||
attention flags.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
from axolotl.utils.schemas.enums import (
|
||||
_NO_DTYPE_CAST_ATTN_IMPLS,
|
||||
_NON_PACKING_ATTN_IMPLS,
|
||||
FLASH_ATTN_LIB_IMPLS,
|
||||
)
|
||||
|
||||
|
||||
class TestAttnImplementationNormalizer:
|
||||
@@ -18,22 +24,31 @@ class TestAttnImplementationNormalizer:
|
||||
# --- Forward mapping: attn_implementation -> legacy flags ---
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"impl,expected_flags",
|
||||
"impl,expected_flag",
|
||||
[
|
||||
("eager", {"eager_attention": True}),
|
||||
("flash", {"flash_attention": True}),
|
||||
("sdpa", {"sdp_attention": True}),
|
||||
("flex", {"flex_attention": True}),
|
||||
("xformers", {"xformers_attention": True, "flash_attention": True}),
|
||||
("sage", {"sage_attention": True, "flash_attention": True}),
|
||||
("s2", {"s2_attention": True, "flash_attention": True}),
|
||||
("eager", "eager_attention"),
|
||||
("flash", "flash_attention"),
|
||||
("sdpa", "sdp_attention"),
|
||||
("flex", "flex_attention"),
|
||||
("xformers", "xformers_attention"),
|
||||
("sage", "sage_attention"),
|
||||
("s2", "s2_attention"),
|
||||
],
|
||||
)
|
||||
def test_attn_impl_sets_legacy_flags(self, impl, expected_flags):
|
||||
def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
|
||||
data = {"attn_implementation": impl}
|
||||
result = AxolotlInputConfig.normalize_attn_implementation(data)
|
||||
for flag, val in expected_flags.items():
|
||||
assert result.get(flag) == val, f"{impl}: expected {flag}={val}"
|
||||
assert result.get(expected_flag) is True, (
|
||||
f"{impl}: expected {expected_flag}=True"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
|
||||
def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
|
||||
"""xformers, sage, s2 should NOT set flash_attention=True anymore."""
|
||||
result = self._normalize({"attn_implementation": impl})
|
||||
assert not result.get("flash_attention"), (
|
||||
f"{impl} should not set flash_attention"
|
||||
)
|
||||
|
||||
def test_fp8_sets_no_legacy_flags(self):
|
||||
result = self._normalize({"attn_implementation": "fp8"})
|
||||
@@ -87,27 +102,13 @@ class TestAttnImplementationNormalizer:
|
||||
assert result["attn_implementation"] == "flash"
|
||||
assert result["flash_attention"] is True
|
||||
|
||||
def test_consistent_xformers_with_extra_flags(self):
|
||||
"""xformers needs flash_attention=True, so both flags with attn_impl should be OK."""
|
||||
def test_consistent_xformers_with_own_flag(self):
|
||||
"""xformers + xformers_attention should be OK."""
|
||||
result = self._normalize(
|
||||
{
|
||||
"attn_implementation": "xformers",
|
||||
"xformers_attention": True,
|
||||
"flash_attention": True,
|
||||
}
|
||||
{"attn_implementation": "xformers", "xformers_attention": True}
|
||||
)
|
||||
assert result["attn_implementation"] == "xformers"
|
||||
|
||||
def test_consistent_s2_with_flash(self):
|
||||
result = self._normalize(
|
||||
{
|
||||
"attn_implementation": "s2",
|
||||
"s2_attention": True,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
assert result["attn_implementation"] == "s2"
|
||||
|
||||
# --- Conflict detection ---
|
||||
|
||||
def test_conflicting_impl_and_flag_raises(self):
|
||||
@@ -118,6 +119,28 @@ class TestAttnImplementationNormalizer:
|
||||
with pytest.raises(ValueError, match="conflicts with"):
|
||||
self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
|
||||
|
||||
def test_xformers_with_flash_flag_conflicts(self):
|
||||
"""After normalizer change, xformers no longer expects flash_attention."""
|
||||
with pytest.raises(ValueError, match="conflicts with"):
|
||||
self._normalize(
|
||||
{
|
||||
"attn_implementation": "xformers",
|
||||
"xformers_attention": True,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
def test_s2_with_flash_flag_conflicts(self):
|
||||
"""After normalizer change, s2 no longer expects flash_attention."""
|
||||
with pytest.raises(ValueError, match="conflicts with"):
|
||||
self._normalize(
|
||||
{
|
||||
"attn_implementation": "s2",
|
||||
"s2_attention": True,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Hub kernel strings pass through ---
|
||||
|
||||
def test_hub_kernel_passthrough(self):
|
||||
@@ -144,16 +167,69 @@ class TestAttnImplementationNormalizer:
|
||||
result = self._normalize({"some_other_config": True})
|
||||
assert result.get("attn_implementation") is None
|
||||
|
||||
# --- Sample packing interactions ---
|
||||
# --- Gemma4 hybrid ---
|
||||
|
||||
def test_xformers_with_sample_packing_sets_flash(self):
|
||||
"""xformers + sample_packing needs flash_attention=True for the patch chain."""
|
||||
result = self._normalize(
|
||||
{"attn_implementation": "xformers", "sample_packing": True}
|
||||
)
|
||||
assert result["xformers_attention"] is True
|
||||
def test_gemma4_hybrid_sets_flash(self):
|
||||
"""gemma4_hybrid_attn_impl should default attn_implementation to flash."""
|
||||
result = self._normalize({"gemma4_hybrid_attn_impl": True})
|
||||
assert result["attn_implementation"] == "flash"
|
||||
assert result["flash_attention"] is True
|
||||
|
||||
def test_gemma4_hybrid_does_not_override_explicit(self):
|
||||
"""If attn_implementation is already set, gemma4 should not override it."""
|
||||
result = self._normalize(
|
||||
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
|
||||
)
|
||||
assert result["attn_implementation"] == "sdpa"
|
||||
|
||||
|
||||
class TestAttnCapabilityProperties:
|
||||
"""Test the capability properties on the normalizer data.
|
||||
|
||||
Since these are @property on AxolotlInputConfig (a Pydantic model),
|
||||
we test the underlying logic directly using the constant sets.
|
||||
"""
|
||||
|
||||
# --- attn_supports_packing ---
|
||||
|
||||
@pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"])
|
||||
def test_supports_packing_true(self, impl):
|
||||
assert impl not in _NON_PACKING_ATTN_IMPLS
|
||||
|
||||
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
|
||||
def test_supports_packing_false(self, impl):
|
||||
assert impl in _NON_PACKING_ATTN_IMPLS
|
||||
|
||||
def test_hub_kernel_supports_packing(self):
|
||||
"""Unknown hub kernels should default to packing-capable."""
|
||||
assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS
|
||||
|
||||
# --- attn_uses_flash_lib ---
|
||||
|
||||
@pytest.mark.parametrize("impl", ["flash", "s2"])
|
||||
def test_uses_flash_lib_true(self, impl):
|
||||
assert impl in FLASH_ATTN_LIB_IMPLS
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"]
|
||||
)
|
||||
def test_uses_flash_lib_false(self, impl):
|
||||
assert impl not in FLASH_ATTN_LIB_IMPLS
|
||||
|
||||
def test_hub_kernel_not_flash_lib(self):
|
||||
"""Hub kernels are HF-managed, not axolotl monkeypatch targets."""
|
||||
assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS
|
||||
|
||||
# --- attn_needs_dtype_cast ---
|
||||
|
||||
@pytest.mark.parametrize("impl", ["eager", "sdpa"])
|
||||
def test_no_dtype_cast(self, impl):
|
||||
assert impl in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||
|
||||
@pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"])
|
||||
def test_needs_dtype_cast(self, impl):
|
||||
assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||
|
||||
|
||||
class TestAttnImplToHFMapping:
|
||||
"""Test that attn_implementation enum values map correctly to HF strings."""
|
||||
|
||||
Reference in New Issue
Block a user