diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 00e5303bd..e111ca9f9 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -147,7 +147,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         load_in_8bit=False,
         load_in_4bit=False,
         quantize_moe_experts=False,
-        flash_attention=False,
+        attn_implementation=None,
         context_parallel_size=None,
         deepspeed=None,
         fsdp=None,
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index fe832dd45..d0b298bee 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -257,19 +257,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
         training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
-        training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
-            self.cfg.flash_attention
-            or self.cfg.xformers_attention
-            or self.cfg.flex_attention
+        training_arguments_kwargs["sample_packing_drop_attention_mask"] = (
+            self.cfg.attn_supports_packing
         )
         training_arguments_kwargs["multipack_real_batches"] = (
             self.cfg.multipack_real_batches
             if self.cfg.multipack_real_batches is not None
-            else not (
-                self.cfg.flash_attention
-                or self.cfg.flex_attention
-                or self.cfg.xformers_attention
-            )
+            else not self.cfg.attn_supports_packing
         )
         training_arguments_kwargs["eval_sample_packing"] = bool(
             self.cfg.eval_sample_packing
@@ -508,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
         # supported multipack models, or non-flash-attention llama
         if (
-            self.cfg.flex_attention
+            self.cfg.attn_implementation == "flex"
             or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
             or (
                 self.cfg.model_config_type in ["llama"]
-                and self.cfg.flash_attention is not True
+                and self.cfg.attn_implementation != "flash"
             )
         ):
             collator = V2BatchSamplerDataCollatorForSeq2Seq
diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py
index 6a82dd6cf..53386956a 100644
--- a/src/axolotl/integrations/lm_eval/__init__.py
+++ b/src/axolotl/integrations/lm_eval/__init__.py
@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
         for lm_eval_args in build_lm_eval_command(
             cfg.lm_eval_tasks,
             bfloat16=cfg.bfloat16 or cfg.bf16,
-            flash_attention=cfg.flash_attention,
+            flash_attention=(cfg.attn_implementation == "flash"),
             output_dir=cfg.output_dir,
             batch_size=cfg.lm_eval_batch_size,
             wandb_project=cfg.wandb_project,
diff --git a/src/axolotl/integrations/swanlab/plugins.py b/src/axolotl/integrations/swanlab/plugins.py
index 16218d39d..55f19ac59 100644
--- a/src/axolotl/integrations/swanlab/plugins.py
+++ b/src/axolotl/integrations/swanlab/plugins.py
@@ -383,7 +383,9 @@ class SwanLabPlugin(BasePlugin):
             "seed": safe_convert(getattr(cfg, "seed", None)),
             "bf16": safe_convert(getattr(cfg, "bf16", None)),
             "tf32": safe_convert(getattr(cfg, "tf32", None)),
-            "flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
+            "attn_implementation": safe_convert(
+                getattr(cfg, "attn_implementation", None)
+            ),
             "sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
         }

diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index c2dbf00aa..0847c9b79 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -343,12 +343,7 @@ class ModelLoader:
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
             # we need to convert them back to fp16/bf16 for flash-attn compatibility.
             (
-                (
-                    needs_fa2_dtype
-                    or self.cfg.flash_attention
-                    or self.cfg.flex_attention
-                    or self.cfg.sage_attention
-                )
+                (needs_fa2_dtype or self.cfg.attn_needs_dtype_cast)
                 and not self.is_qlora_and_fsdp_enabled
             )
             or (
@@ -656,32 +651,12 @@ class ModelLoader:
             # global layers will be patched to sdpa post-load.
             self.model_kwargs["attn_implementation"] = "flash_attention_2"
             self.model_config._attn_implementation = "flash_attention_2"
-            # Set flash_attention so multipack/sample_packing patches activate
-            self.cfg.flash_attention = True
         elif self.cfg.attn_implementation:
             hf_impl = _ATTN_IMPL_TO_HF.get(
                 self.cfg.attn_implementation, self.cfg.attn_implementation
             )
             self.model_kwargs["attn_implementation"] = hf_impl
             self.model_config._attn_implementation = hf_impl
-        elif self.cfg.flex_attention:
-            self.model_kwargs["attn_implementation"] = "flex_attention"
-            self.model_config._attn_implementation = "flex_attention"
-        elif self.cfg.flash_attention:
-            if not self.cfg.sample_packing and self.cfg.s2_attention:
-                pass
-            self.model_kwargs["attn_implementation"] = "flash_attention_2"
-            self.model_config._attn_implementation = "flash_attention_2"
-        elif self.cfg.sdp_attention:
-            self.model_kwargs["attn_implementation"] = "sdpa"
-            self.model_config._attn_implementation = "sdpa"
-        elif self.cfg.sage_attention:
-            # sets FA2 attention to re-use same internal handling like masking
-            self.model_kwargs["attn_implementation"] = "flash_attention_2"
-            self.model_config._attn_implementation = "flash_attention_2"
-        elif self.cfg.eager_attention:
-            self.model_kwargs["attn_implementation"] = "eager"
-            self.model_config._attn_implementation = "eager"

         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index ebe0e6474..5783c5996 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -253,7 +253,7 @@ class PatchManager:
     def _apply_flash_attention_patches(self):
         """Apply patches related to Flash Attention."""
-        if self.cfg.xformers_attention:
+        if self.cfg.attn_implementation == "xformers":
             from axolotl.monkeypatch.attention import register_xformers_attn

             register_xformers_attn()

@@ -263,9 +263,8 @@ class PatchManager:
             from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2

             patch_xformers_attn_over_fa2()
-            self.cfg.flash_attention = True

-        if self.cfg.sage_attention:
+        if self.cfg.attn_implementation == "sage":
             from axolotl.monkeypatch.attention import register_sage_attn

             register_sage_attn()
@@ -334,7 +333,7 @@ class PatchManager:

     def _apply_flex_attention_patches(self):
         """Apply patches for flexible attention."""
-        if self.cfg.flex_attention:
+        if self.cfg.attn_implementation == "flex":
             from axolotl.monkeypatch.attention.flex_attn import (
                 patch_flex_wrapper,
             )
@@ -344,14 +343,14 @@ class PatchManager:

     def _apply_sageattn_patches(self):
         """Apply patches for SageAttention."""
-        if self.cfg.sage_attention:
+        if self.cfg.attn_implementation == "sage":
             from axolotl.monkeypatch.attention.sage_attn import patch_sageattn

             patch_sageattn()

     def _apply_flash_attn_4_patches(self):
         """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
-        if not self.cfg.flash_attention:
+        if not self.cfg.attn_uses_flash_lib:
             return

         from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
@@ -420,7 +419,7 @@ class PatchManager:
         if (
             self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
             and self.cfg.is_multimodal
-            and self.cfg.flash_attention
+            and self.cfg.attn_uses_flash_lib
         ):
             from axolotl.monkeypatch.models.qwen3_5.modeling import (
                 patch_qwen3_5_vlm_flash_attention,
             )
@@ -572,7 +571,7 @@ class PatchManager:
         """Apply multipack patches if necessary."""
         if (
             self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
-            and (self.cfg.flash_attention or self.cfg.flex_attention)
+            and self.cfg.attn_supports_packing
             and self.cfg.sample_packing
         ):
             # Get automap config if it exists
@@ -693,7 +692,9 @@ class PatchManager:

     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
-        if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
+        if not (
+            self.cfg.attn_uses_flash_lib and hasattr(self.model_config, "model_type")
+        ):
             return

         if self.model_config.model_type == "btlm":
@@ -739,7 +740,7 @@ class PatchManager:
             replace_llama_attn_with_flash_attn,
         )

-        if self.cfg.s2_attention:
+        if self.cfg.attn_implementation == "s2":
             LOG.info("patching w/ flash-enabled, shifted-sparse attention")
             replace_llama_attn_with_flash_attn(
                 cross_entropy=self.cfg.flash_attn_cross_entropy,
@@ -765,14 +766,14 @@ class PatchManager:
         """Modify all llama derived models in one block."""
         if self.cfg.is_llama_derived_model and not (
             self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
-            and (self.cfg.flash_attention or self.cfg.flex_attention)
+            and self.cfg.attn_supports_packing
             and self.cfg.sample_packing
         ):
-            if self.cfg.flash_attention:
+            if self.cfg.attn_uses_flash_lib:
                 self._patch_llama_flash_attention()
-            elif self.cfg.xformers_attention:
+            elif self.cfg.attn_implementation == "xformers":
                 self._patch_llama_xformers_attention()
-            elif self.cfg.s2_attention:
+            elif self.cfg.attn_implementation == "s2":
                 raise NotImplementedError(
                     "Shifted-sparse attention not currently implemented without flash attention."
                 )
@@ -784,7 +785,7 @@ class PatchManager:
             in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
             and not self.cfg.trust_remote_code
             and not self.cfg.gptq
-            and self.cfg.flash_attention
+            and self.cfg.attn_uses_flash_lib
             and is_flash_attn_available()
             and not self.inference
         ):
diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py
index 52f714604..48f4a9fa3 100644
--- a/src/axolotl/loaders/tokenizer.py
+++ b/src/axolotl/loaders/tokenizer.py
@@ -205,7 +205,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
         os.environ["TOKENIZERS_PARALLELISM"] = "false"

     # Mistral's official FA implementation requires left padding
-    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
+    if (
+        cfg.is_mistral_derived_model
+        and cfg.attn_implementation == "flash"
+        and not cfg.sample_packing
+    ):
         tokenizer.padding_side = "left"

     # Qwen base only has single token, so we need to set the special tokens
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 5635e1261..edb61441b 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -955,9 +955,9 @@ def colab_inference_post_train_callback(trainer: Trainer):
         """
        handle T4 gpu, we need to convert attention to eager for inference
        """
-        if "Tesla T4" in self.gpu_name and (
-            self.cfg.xformers_attention
-            or self.cfg.attn_implementation == "xformers"
+        if (
+            "Tesla T4" in self.gpu_name
+            and self.cfg.attn_implementation == "xformers"
         ):
             trainer.model.config._attn_implementation = "eager"
             trainer.model.gradient_checkpointing_disable()
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 6c579efa5..b9278068c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -28,6 +28,9 @@ from axolotl.utils.schemas.datasets import (
 from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
 from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
 from axolotl.utils.schemas.enums import (
+    _NO_DTYPE_CAST_ATTN_IMPLS,
+    _NON_PACKING_ATTN_IMPLS,
+    FLASH_ATTN_LIB_IMPLS,
     AttnImplementation,
     ChatTemplate,
     RingAttnFunc,
@@ -1332,6 +1335,40 @@ class AxolotlInputConfig(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None

+    # --- Attention capability properties ---
+
+    @property
+    def attn_supports_packing(self) -> bool:
+        """True if attention supports varlen sample packing via position_ids.
+
+        Known varlen backends: flash, flex, xformers, sage.
+        Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
+        default to True since they generally support varlen.
+        """
+        if not self.attn_implementation:
+            return False
+        return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
+
+    @property
+    def attn_uses_flash_lib(self) -> bool:
+        """True if the backend uses axolotl's flash_attn monkeypatches.
+
+        Only for axolotl-managed FA setup (flash, s2). Hub kernels are
+        HF-managed and don't need these patches.
+        """
+        return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
+
+    @property
+    def attn_needs_dtype_cast(self) -> bool:
+        """True if attention needs embedding dtype cast to fp16/bf16.
+
+        Unknown backends (hub kernels) default to True (safe -- harmless
+        if unnecessary, but missing cast causes errors).
+        """
+        if not self.attn_implementation:
+            return False
+        return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
+
     @model_validator(mode="before")
     @classmethod
     def warn_peft_trainable_token_to_fix_untrained(cls, data):
@@ -1358,16 +1395,22 @@ class AxolotlInputConfig(
         """Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
         attn_impl = data.get("attn_implementation")

-        # Mapping: attn_implementation value -> (primary flag, extra flags to set)
-        impl_to_flags = {
-            "eager": (("eager_attention",), ()),
-            "flash": (("flash_attention",), ()),
-            "sdpa": (("sdp_attention",), ()),
-            "xformers": (("xformers_attention",), ("flash_attention",)),
-            "flex": (("flex_attention",), ()),
-            "sage": (("sage_attention",), ("flash_attention",)),
-            "s2": (("s2_attention",), ("flash_attention",)),
-            "fp8": ((), ()),  # new, no legacy flags
+        # If gemma4_hybrid_attn_impl is set but no attn_implementation, default
+        # to flash (the sliding-window layers use FA2, and packing should be enabled).
+        if data.get("gemma4_hybrid_attn_impl") and not attn_impl:
+            data["attn_implementation"] = "flash"
+            attn_impl = "flash"
+
+        # Mapping: attn_implementation value -> primary legacy flag to set
+        impl_to_flag = {
+            "eager": "eager_attention",
+            "flash": "flash_attention",
+            "sdpa": "sdp_attention",
+            "xformers": "xformers_attention",
+            "flex": "flex_attention",
+            "sage": "sage_attention",
+            "s2": "s2_attention",
+            "fp8": None,  # new, no legacy flag
         }

         # Reverse mapping: legacy flag -> attn_implementation value
@@ -1386,26 +1429,21 @@ class AxolotlInputConfig(

         if attn_impl and set_flags:
             # Both set — check consistency
-            if attn_impl in impl_to_flags:
-                expected_primary, expected_extra = impl_to_flags[attn_impl]
-                expected_flags = set(expected_primary) | set(expected_extra)
-                for flag in set_flags:
-                    if flag not in expected_flags:
-                        raise ValueError(
-                            f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
-                            f"Use only attn_implementation or the legacy flag, not both."
-                        )
+            expected_flag = impl_to_flag.get(attn_impl)
+            for flag in set_flags:
+                if flag != expected_flag:
+                    raise ValueError(
+                        f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
+                        f"Use only attn_implementation or the legacy flag, not both."
+                    )
         elif attn_impl and not set_flags:
-            # attn_implementation set, no legacy flags — set them for backwards compat
-            if attn_impl in impl_to_flags:
-                primary, extra = impl_to_flags[attn_impl]
-                for flag in (*primary, *extra):
-                    data[flag] = True
+            # attn_implementation set, no legacy flags — set primary for backwards compat
+            flag = impl_to_flag.get(attn_impl)
+            if flag:
+                data[flag] = True
         elif not attn_impl and set_flags:
             # Legacy flags set, no attn_implementation — map to enum, warn
             # Priority: specific backends first, then generic flash/sdp/eager
-            # s2 and sage require flash_attention internally, so they must be
-            # checked before flash_attention to avoid masking
             priority = [
                 "xformers_attention",
                 "s2_attention",
@@ -1430,7 +1468,10 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_sageattn_wo_sample_packing(cls, data):
-        if (not data.get("sample_packing", False)) and data.get("sage_attention"):
+        is_sage = (
+            data.get("sage_attention") or data.get("attn_implementation") == "sage"
+        )
+        if (not data.get("sample_packing", False)) and is_sage:
            if not data.get("pad_to_sequence_len", False):
                LOG.warning(
                    "We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
@@ -1441,7 +1482,10 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_sageattn_fft(cls, data):
-        if (not data.get("adapter", False)) and data.get("sage_attention"):
+        is_sage = (
+            data.get("sage_attention") or data.get("attn_implementation") == "sage"
+        )
+        if (not data.get("adapter", False)) and is_sage:
             LOG.warning(
                 "We found loss to drop to 0 with SageAttention full finetuning."
                 "Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
@@ -1531,7 +1575,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         )
         if (
             data.get("sample_packing")
-            and data.get("sdp_attention")
+            and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
             and (data.get("bfloat16") or data.get("bf16"))
             and not is_sm_90
         ):
@@ -1546,8 +1590,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_compute_capability_w_sageattn(cls, data):
+        is_sage = (
+            data.get("sage_attention") or data.get("attn_implementation") == "sage"
+        )
         if (
-            data.get("sage_attention")
+            is_sage
             and data.get("capabilities")
             and data.get("capabilities").get("compute_capability")
             not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
@@ -1715,7 +1762,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

     @model_validator(mode="before")
     @classmethod
     def check_flex_torch_version(cls, data):
-        if (data.get("flex_attention") is not None) and (data.get("flex_attention")):
+        if data.get("flex_attention") or data.get("attn_implementation") == "flex":
             env_capabilities = data.get("env_capabilities", {})
             torch_version = env_capabilities.get("torch_version")
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 12d59c974..f01d4bd7a 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -110,6 +110,19 @@ class AttnImplementation(str, Enum):
     fp8 = "fp8"  # pylint: disable=invalid-name


+# Backends that require the flash_attn library (Dao-AILab/flash-attention)
+# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.)
+FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
+
+# Known backends that do NOT support varlen sample packing via position_ids.
+# Used as an exclusion list: unknown strings (e.g., HF hub kernels like
+# "kernels-community/flash-attn3") default to packing-capable.
+_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
+
+# Known backends that do NOT need embedding dtype cast.
+_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})
+
+
 class RingAttnFunc(str, Enum):
     """Enum class for supported `ring-flash-attn` implementations"""
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index b83396d4e..15ab26bfa 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -12,7 +12,12 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available

 from axolotl.utils.logging import get_logger
-from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
+from axolotl.utils.schemas.enums import (
+    _NON_PACKING_ATTN_IMPLS,
+    ChatTemplate,
+    RingAttnFunc,
+    RLType,
+)

 LOG = get_logger(__name__)
@@ -182,6 +187,10 @@ class AttentionValidationMixin:
     @model_validator(mode="before")
     @classmethod
     def check_attention_fields(cls, data):
+        # If attn_implementation is set, the enum handles mutual exclusivity.
+        # This validator catches legacy configs with multiple boolean flags.
+        if data.get("attn_implementation"):
+            return data
         fields = (
             "xformers_attention",
             "sdp_attention",
@@ -436,7 +445,7 @@ class TrainingValidationMixin:
             not (self.bf16 or self.bfloat16)
             and (self.fp16 or self.float16)
             and not self.adapter
-            and not self.flash_attention
+            and not self.attn_uses_flash_lib
             and self.sample_packing
         ):
             LOG.warning(
@@ -946,8 +955,16 @@ class OptimizationValidationMixin:
     def check_batch_flattening_fa(cls, data):
         if data.get("batch_flattening"):
             batch_flattening_auto = data.get("batch_flattening") == "auto"
-            if not data.get("flash_attention") and not batch_flattening_auto:
-                raise ValueError("batch_flattening requires flash attention")
+            has_varlen_attn = (
+                data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS
+                if data.get("attn_implementation")
+                else data.get("flash_attention")
+            )
+            if not has_varlen_attn and not batch_flattening_auto:
+                raise ValueError(
+                    "batch_flattening requires a varlen-capable attention backend "
+                    "(e.g., attn_implementation: flash)"
+                )
             if data.get("sample_packing") and not batch_flattening_auto:
                 raise ValueError("batch_flattening not compatible with sample_packing")
             if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
@@ -966,7 +983,7 @@ class OptimizationValidationMixin:

             if (
                 batch_flattening_auto
-                and data.get("flash_attention")
+                and has_varlen_attn
                 and not data.get("sample_packing")
                 and data.get("micro_batch_size") > 1
             ):
@@ -1211,6 +1228,12 @@ class SystemValidationMixin:
     def check_npu_config(cls, data):
         if is_torch_npu_available():
             # check attention config
+            unsupported_npu_impls = {"flash", "sdpa", "s2"}
+            attn_impl = data.get("attn_implementation")
+            if attn_impl and attn_impl in unsupported_npu_impls:
+                raise NotImplementedError(
+                    f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
+                )
             attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
             for attn in attn_list:
                 if data.get(attn):
@@ -1519,9 +1542,10 @@ class ComplexValidationMixin:
         if not self.context_parallel_size:
             self.context_parallel_size = 1
         elif self.context_parallel_size > 1:
-            if not self.flash_attention:
+            if not self.attn_uses_flash_lib:
                 raise ValueError(
-                    "flash_attention: true must be set with context_parallel_size > 1"
+                    "context_parallel_size > 1 requires flash attention "
+                    "(attn_implementation: flash or s2)."
                 )

             if self.sample_packing and self.micro_batch_size > 1:
@@ -1658,7 +1682,9 @@ class EBFTValidationMixin:
         if (
             data.get("rl") == "ebft"
             and data.get("ebft", {}).get("mode") == "strided"
-            and data.get("flex_attention")
+            and (
+                data.get("flex_attention") or data.get("attn_implementation") == "flex"
+            )
             and data.get("gradient_checkpointing")
         ):
             gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 91982137b..3fb940364 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -462,7 +462,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
             )
         else:
-            if cfg.flash_attention and not cfg.multipack_real_batches:
+            if cfg.attn_supports_packing and not cfg.multipack_real_batches:
                 sampler_batch_size = 1
                 batch_max_len = cfg.micro_batch_size * cfg.sequence_len
             else:
diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py
index 71c052718..1973d5f74 100644
--- a/tests/test_attn_implementation.py
+++ b/tests/test_attn_implementation.py
@@ -1,11 +1,17 @@
 """
-Tests for attn_implementation normalization, registry registration, and
-backwards compatibility with legacy boolean attention flags.
+Tests for attn_implementation normalization, registry registration,
+capability properties, and backwards compatibility with legacy boolean
+attention flags.
 """

 import pytest

 from axolotl.utils.schemas.config import AxolotlInputConfig
+from axolotl.utils.schemas.enums import (
+    _NO_DTYPE_CAST_ATTN_IMPLS,
+    _NON_PACKING_ATTN_IMPLS,
+    FLASH_ATTN_LIB_IMPLS,
+)


 class TestAttnImplementationNormalizer:
@@ -18,22 +24,31 @@ class TestAttnImplementationNormalizer:
     # --- Forward mapping: attn_implementation -> legacy flags ---

     @pytest.mark.parametrize(
-        "impl,expected_flags",
+        "impl,expected_flag",
         [
-            ("eager", {"eager_attention": True}),
-            ("flash", {"flash_attention": True}),
-            ("sdpa", {"sdp_attention": True}),
-            ("flex", {"flex_attention": True}),
-            ("xformers", {"xformers_attention": True, "flash_attention": True}),
-            ("sage", {"sage_attention": True, "flash_attention": True}),
-            ("s2", {"s2_attention": True, "flash_attention": True}),
+            ("eager", "eager_attention"),
+            ("flash", "flash_attention"),
+            ("sdpa", "sdp_attention"),
+            ("flex", "flex_attention"),
+            ("xformers", "xformers_attention"),
+            ("sage", "sage_attention"),
+            ("s2", "s2_attention"),
         ],
     )
-    def test_attn_impl_sets_legacy_flags(self, impl, expected_flags):
+    def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
         data = {"attn_implementation": impl}
         result = AxolotlInputConfig.normalize_attn_implementation(data)
-        for flag, val in expected_flags.items():
-            assert result.get(flag) == val, f"{impl}: expected {flag}={val}"
+        assert result.get(expected_flag) is True, (
+            f"{impl}: expected {expected_flag}=True"
+        )
+
+    @pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
+    def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
+        """xformers, sage, s2 should NOT set flash_attention=True anymore."""
+        result = self._normalize({"attn_implementation": impl})
+        assert not result.get("flash_attention"), (
+            f"{impl} should not set flash_attention"
+        )

     def test_fp8_sets_no_legacy_flags(self):
         result = self._normalize({"attn_implementation": "fp8"})
@@ -87,27 +102,13 @@ class TestAttnImplementationNormalizer:
         assert result["attn_implementation"] == "flash"
         assert result["flash_attention"] is True

-    def test_consistent_xformers_with_extra_flags(self):
-        """xformers needs flash_attention=True, so both flags with attn_impl should be OK."""
+    def test_consistent_xformers_with_own_flag(self):
+        """xformers + xformers_attention should be OK."""
         result = self._normalize(
-            {
-                "attn_implementation": "xformers",
-                "xformers_attention": True,
-                "flash_attention": True,
-            }
+            {"attn_implementation": "xformers", "xformers_attention": True}
         )
         assert result["attn_implementation"] == "xformers"

-    def test_consistent_s2_with_flash(self):
-        result = self._normalize(
-            {
-                "attn_implementation": "s2",
-                "s2_attention": True,
-                "flash_attention": True,
-            }
-        )
-        assert result["attn_implementation"] == "s2"
-
     # --- Conflict detection ---

     def test_conflicting_impl_and_flag_raises(self):
@@ -118,6 +119,28 @@
         with pytest.raises(ValueError, match="conflicts with"):
             self._normalize({"attn_implementation": "xformers", "sdp_attention": True})

+    def test_xformers_with_flash_flag_conflicts(self):
+        """After normalizer change, xformers no longer expects flash_attention."""
+        with pytest.raises(ValueError, match="conflicts with"):
+            self._normalize(
+                {
+                    "attn_implementation": "xformers",
+                    "xformers_attention": True,
+                    "flash_attention": True,
+                }
+            )
+
+    def test_s2_with_flash_flag_conflicts(self):
+        """After normalizer change, s2 no longer expects flash_attention."""
+        with pytest.raises(ValueError, match="conflicts with"):
+            self._normalize(
+                {
+                    "attn_implementation": "s2",
+                    "s2_attention": True,
+                    "flash_attention": True,
+                }
+            )
+
     # --- Hub kernel strings pass through ---

     def test_hub_kernel_passthrough(self):
@@ -144,16 +167,69 @@
         result = self._normalize({"some_other_config": True})
         assert result.get("attn_implementation") is None

-    # --- Sample packing interactions ---
+    # --- Gemma4 hybrid ---

-    def test_xformers_with_sample_packing_sets_flash(self):
-        """xformers + sample_packing needs flash_attention=True for the patch chain."""
-        result = self._normalize(
-            {"attn_implementation": "xformers", "sample_packing": True}
-        )
-        assert result["xformers_attention"] is True
+    def test_gemma4_hybrid_sets_flash(self):
+        """gemma4_hybrid_attn_impl should default attn_implementation to flash."""
+        result = self._normalize({"gemma4_hybrid_attn_impl": True})
+        assert result["attn_implementation"] == "flash"
         assert result["flash_attention"] is True

+    def test_gemma4_hybrid_does_not_override_explicit(self):
+        """If attn_implementation is already set, gemma4 should not override it."""
+        result = self._normalize(
+            {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
+        )
+        assert result["attn_implementation"] == "sdpa"
+
+
+class TestAttnCapabilityProperties:
+    """Test the capability properties on the normalizer data.
+
+    Since these are @property on AxolotlInputConfig (a Pydantic model),
+    we test the underlying logic directly using the constant sets.
+    """
+
+    # --- attn_supports_packing ---
+
+    @pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"])
+    def test_supports_packing_true(self, impl):
+        assert impl not in _NON_PACKING_ATTN_IMPLS
+
+    @pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
+    def test_supports_packing_false(self, impl):
+        assert impl in _NON_PACKING_ATTN_IMPLS
+
+    def test_hub_kernel_supports_packing(self):
+        """Unknown hub kernels should default to packing-capable."""
+        assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS
+
+    # --- attn_uses_flash_lib ---
+
+    @pytest.mark.parametrize("impl", ["flash", "s2"])
+    def test_uses_flash_lib_true(self, impl):
+        assert impl in FLASH_ATTN_LIB_IMPLS
+
+    @pytest.mark.parametrize(
+        "impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"]
+    )
+    def test_uses_flash_lib_false(self, impl):
+        assert impl not in FLASH_ATTN_LIB_IMPLS
+
+    def test_hub_kernel_not_flash_lib(self):
+        """Hub kernels are HF-managed, not axolotl monkeypatch targets."""
+        assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS
+
+    # --- attn_needs_dtype_cast ---
+
+    @pytest.mark.parametrize("impl", ["eager", "sdpa"])
+    def test_no_dtype_cast(self, impl):
+        assert impl in _NO_DTYPE_CAST_ATTN_IMPLS
+
+    @pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"])
+    def test_needs_dtype_cast(self, impl):
+        assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
+

 class TestAttnImplToHFMapping:
     """Test that attn_implementation enum values map correctly to HF strings."""
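A minimal illustrative sketch (not part of the diff) of how the three new capability properties classify a backend string, mirroring the frozensets added in src/axolotl/utils/schemas/enums.py. The describe() helper below is hypothetical and exists only for demonstration; the real logic lives on AxolotlInputConfig as the attn_supports_packing, attn_uses_flash_lib, and attn_needs_dtype_cast properties.

```python
# Illustrative sketch only -- mirrors the capability sets from this PR.
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})


def describe(attn_implementation):
    """Hypothetical helper mirroring the three AxolotlInputConfig properties."""
    if not attn_implementation:
        return {"packing": False, "flash_lib": False, "dtype_cast": False}
    return {
        # unknown strings (e.g. hub kernel names) fall through to True here
        "packing": attn_implementation not in _NON_PACKING_ATTN_IMPLS,
        "flash_lib": attn_implementation in FLASH_ATTN_LIB_IMPLS,
        "dtype_cast": attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS,
    }


print(describe("flash"))  # {'packing': True, 'flash_lib': True, 'dtype_cast': True}
print(describe("sdpa"))   # {'packing': False, 'flash_lib': False, 'dtype_cast': False}
print(describe("kernels-community/flash-attn3"))  # unknown hub kernel: treated as varlen-capable
```

Exclusion lists are used on purpose: an unrecognized attn_implementation string (such as a hub kernel) defaults to packing-capable and dtype-cast-needed, while only the explicitly known "flash" and "s2" values trigger axolotl's flash_attn monkeypatches.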