replace legacy attention boolean flags with capability properties
Replace the legacy boolean-flag checks (flash_attention, flex_attention, xformers_attention, sage_attention, s2_attention) with capability-based properties derived from attn_implementation. This separates three concerns that were previously conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property
This commit is contained in:
@@ -147,7 +147,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
quantize_moe_experts=False,
|
quantize_moe_experts=False,
|
||||||
flash_attention=False,
|
attn_implementation=None,
|
||||||
context_parallel_size=None,
|
context_parallel_size=None,
|
||||||
deepspeed=None,
|
deepspeed=None,
|
||||||
fsdp=None,
|
fsdp=None,
|
||||||
|
|||||||
@@ -257,19 +257,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
|
|
||||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||||
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
|
training_arguments_kwargs["sample_packing_drop_attention_mask"] = (
|
||||||
self.cfg.flash_attention
|
self.cfg.attn_supports_packing
|
||||||
or self.cfg.xformers_attention
|
|
||||||
or self.cfg.flex_attention
|
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["multipack_real_batches"] = (
|
training_arguments_kwargs["multipack_real_batches"] = (
|
||||||
self.cfg.multipack_real_batches
|
self.cfg.multipack_real_batches
|
||||||
if self.cfg.multipack_real_batches is not None
|
if self.cfg.multipack_real_batches is not None
|
||||||
else not (
|
else not self.cfg.attn_supports_packing
|
||||||
self.cfg.flash_attention
|
|
||||||
or self.cfg.flex_attention
|
|
||||||
or self.cfg.xformers_attention
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
@@ -508,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||||
# supported multipack models, or non-flash-attention llama
|
# supported multipack models, or non-flash-attention llama
|
||||||
if (
|
if (
|
||||||
self.cfg.flex_attention
|
self.cfg.attn_implementation == "flex"
|
||||||
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
or (
|
or (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.attn_implementation != "flash"
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin):
|
|||||||
for lm_eval_args in build_lm_eval_command(
|
for lm_eval_args in build_lm_eval_command(
|
||||||
cfg.lm_eval_tasks,
|
cfg.lm_eval_tasks,
|
||||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||||
flash_attention=cfg.flash_attention,
|
flash_attention=(cfg.attn_implementation == "flash"),
|
||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
batch_size=cfg.lm_eval_batch_size,
|
batch_size=cfg.lm_eval_batch_size,
|
||||||
wandb_project=cfg.wandb_project,
|
wandb_project=cfg.wandb_project,
|
||||||
|
|||||||
@@ -383,7 +383,9 @@ class SwanLabPlugin(BasePlugin):
|
|||||||
"seed": safe_convert(getattr(cfg, "seed", None)),
|
"seed": safe_convert(getattr(cfg, "seed", None)),
|
||||||
"bf16": safe_convert(getattr(cfg, "bf16", None)),
|
"bf16": safe_convert(getattr(cfg, "bf16", None)),
|
||||||
"tf32": safe_convert(getattr(cfg, "tf32", None)),
|
"tf32": safe_convert(getattr(cfg, "tf32", None)),
|
||||||
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
|
"attn_implementation": safe_convert(
|
||||||
|
getattr(cfg, "attn_implementation", None)
|
||||||
|
),
|
||||||
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
|
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -343,12 +343,7 @@ class ModelLoader:
|
|||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
|
||||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
(
|
(
|
||||||
(
|
(needs_fa2_dtype or self.cfg.attn_needs_dtype_cast)
|
||||||
needs_fa2_dtype
|
|
||||||
or self.cfg.flash_attention
|
|
||||||
or self.cfg.flex_attention
|
|
||||||
or self.cfg.sage_attention
|
|
||||||
)
|
|
||||||
and not self.is_qlora_and_fsdp_enabled
|
and not self.is_qlora_and_fsdp_enabled
|
||||||
)
|
)
|
||||||
or (
|
or (
|
||||||
@@ -656,32 +651,12 @@ class ModelLoader:
|
|||||||
# global layers will be patched to sdpa post-load.
|
# global layers will be patched to sdpa post-load.
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
self.model_config._attn_implementation = "flash_attention_2"
|
self.model_config._attn_implementation = "flash_attention_2"
|
||||||
# Set flash_attention so multipack/sample_packing patches activate
|
|
||||||
self.cfg.flash_attention = True
|
|
||||||
elif self.cfg.attn_implementation:
|
elif self.cfg.attn_implementation:
|
||||||
hf_impl = _ATTN_IMPL_TO_HF.get(
|
hf_impl = _ATTN_IMPL_TO_HF.get(
|
||||||
self.cfg.attn_implementation, self.cfg.attn_implementation
|
self.cfg.attn_implementation, self.cfg.attn_implementation
|
||||||
)
|
)
|
||||||
self.model_kwargs["attn_implementation"] = hf_impl
|
self.model_kwargs["attn_implementation"] = hf_impl
|
||||||
self.model_config._attn_implementation = hf_impl
|
self.model_config._attn_implementation = hf_impl
|
||||||
elif self.cfg.flex_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
|
||||||
self.model_config._attn_implementation = "flex_attention"
|
|
||||||
elif self.cfg.flash_attention:
|
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
|
||||||
pass
|
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
self.model_config._attn_implementation = "flash_attention_2"
|
|
||||||
elif self.cfg.sdp_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "sdpa"
|
|
||||||
self.model_config._attn_implementation = "sdpa"
|
|
||||||
elif self.cfg.sage_attention:
|
|
||||||
# sets FA2 attention to re-use same internal handling like masking
|
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
self.model_config._attn_implementation = "flash_attention_2"
|
|
||||||
elif self.cfg.eager_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "eager"
|
|
||||||
self.model_config._attn_implementation = "eager"
|
|
||||||
|
|
||||||
if self.cfg.low_cpu_mem_usage:
|
if self.cfg.low_cpu_mem_usage:
|
||||||
self.model_kwargs["low_cpu_mem_usage"] = True
|
self.model_kwargs["low_cpu_mem_usage"] = True
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_flash_attention_patches(self):
|
def _apply_flash_attention_patches(self):
|
||||||
"""Apply patches related to Flash Attention."""
|
"""Apply patches related to Flash Attention."""
|
||||||
if self.cfg.xformers_attention:
|
if self.cfg.attn_implementation == "xformers":
|
||||||
from axolotl.monkeypatch.attention import register_xformers_attn
|
from axolotl.monkeypatch.attention import register_xformers_attn
|
||||||
|
|
||||||
register_xformers_attn()
|
register_xformers_attn()
|
||||||
@@ -263,9 +263,8 @@ class PatchManager:
|
|||||||
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
||||||
|
|
||||||
patch_xformers_attn_over_fa2()
|
patch_xformers_attn_over_fa2()
|
||||||
self.cfg.flash_attention = True
|
|
||||||
|
|
||||||
if self.cfg.sage_attention:
|
if self.cfg.attn_implementation == "sage":
|
||||||
from axolotl.monkeypatch.attention import register_sage_attn
|
from axolotl.monkeypatch.attention import register_sage_attn
|
||||||
|
|
||||||
register_sage_attn()
|
register_sage_attn()
|
||||||
@@ -334,7 +333,7 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_flex_attention_patches(self):
|
def _apply_flex_attention_patches(self):
|
||||||
"""Apply patches for flexible attention."""
|
"""Apply patches for flexible attention."""
|
||||||
if self.cfg.flex_attention:
|
if self.cfg.attn_implementation == "flex":
|
||||||
from axolotl.monkeypatch.attention.flex_attn import (
|
from axolotl.monkeypatch.attention.flex_attn import (
|
||||||
patch_flex_wrapper,
|
patch_flex_wrapper,
|
||||||
)
|
)
|
||||||
@@ -344,14 +343,14 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_sageattn_patches(self):
|
def _apply_sageattn_patches(self):
|
||||||
"""Apply patches for SageAttention."""
|
"""Apply patches for SageAttention."""
|
||||||
if self.cfg.sage_attention:
|
if self.cfg.attn_implementation == "sage":
|
||||||
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
|
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
|
||||||
|
|
||||||
patch_sageattn()
|
patch_sageattn()
|
||||||
|
|
||||||
def _apply_flash_attn_4_patches(self):
|
def _apply_flash_attn_4_patches(self):
|
||||||
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
|
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
|
||||||
if not self.cfg.flash_attention:
|
if not self.cfg.attn_uses_flash_lib:
|
||||||
return
|
return
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
|
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
|
||||||
@@ -420,7 +419,7 @@ class PatchManager:
|
|||||||
if (
|
if (
|
||||||
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
|
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
|
||||||
and self.cfg.is_multimodal
|
and self.cfg.is_multimodal
|
||||||
and self.cfg.flash_attention
|
and self.cfg.attn_uses_flash_lib
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.models.qwen3_5.modeling import (
|
from axolotl.monkeypatch.models.qwen3_5.modeling import (
|
||||||
patch_qwen3_5_vlm_flash_attention,
|
patch_qwen3_5_vlm_flash_attention,
|
||||||
@@ -572,7 +571,7 @@ class PatchManager:
|
|||||||
"""Apply multipack patches if necessary."""
|
"""Apply multipack patches if necessary."""
|
||||||
if (
|
if (
|
||||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
and self.cfg.attn_supports_packing
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
# Get automap config if it exists
|
# Get automap config if it exists
|
||||||
@@ -693,7 +692,9 @@ class PatchManager:
|
|||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
if not (
|
||||||
|
self.cfg.attn_uses_flash_lib and hasattr(self.model_config, "model_type")
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.model_config.model_type == "btlm":
|
if self.model_config.model_type == "btlm":
|
||||||
@@ -739,7 +740,7 @@ class PatchManager:
|
|||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.s2_attention:
|
if self.cfg.attn_implementation == "s2":
|
||||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||||
replace_llama_attn_with_flash_attn(
|
replace_llama_attn_with_flash_attn(
|
||||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||||
@@ -765,14 +766,14 @@ class PatchManager:
|
|||||||
"""Modify all llama derived models in one block."""
|
"""Modify all llama derived models in one block."""
|
||||||
if self.cfg.is_llama_derived_model and not (
|
if self.cfg.is_llama_derived_model and not (
|
||||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
and self.cfg.attn_supports_packing
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.attn_uses_flash_lib:
|
||||||
self._patch_llama_flash_attention()
|
self._patch_llama_flash_attention()
|
||||||
elif self.cfg.xformers_attention:
|
elif self.cfg.attn_implementation == "xformers":
|
||||||
self._patch_llama_xformers_attention()
|
self._patch_llama_xformers_attention()
|
||||||
elif self.cfg.s2_attention:
|
elif self.cfg.attn_implementation == "s2":
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
@@ -784,7 +785,7 @@ class PatchManager:
|
|||||||
in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
|
in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
|
||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
and not self.cfg.gptq
|
and not self.cfg.gptq
|
||||||
and self.cfg.flash_attention
|
and self.cfg.attn_uses_flash_lib
|
||||||
and is_flash_attn_available()
|
and is_flash_attn_available()
|
||||||
and not self.inference
|
and not self.inference
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -205,7 +205,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
# Mistral's official FA implementation requires left padding
|
# Mistral's official FA implementation requires left padding
|
||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
if (
|
||||||
|
cfg.is_mistral_derived_model
|
||||||
|
and cfg.attn_implementation == "flash"
|
||||||
|
and not cfg.sample_packing
|
||||||
|
):
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# Qwen base only has single token, so we need to set the special tokens
|
# Qwen base only has single token, so we need to set the special tokens
|
||||||
|
|||||||
@@ -955,9 +955,9 @@ def colab_inference_post_train_callback(trainer: Trainer):
|
|||||||
"""
|
"""
|
||||||
handle T4 gpu, we need to convert attention to eager for inference
|
handle T4 gpu, we need to convert attention to eager for inference
|
||||||
"""
|
"""
|
||||||
if "Tesla T4" in self.gpu_name and (
|
if (
|
||||||
self.cfg.xformers_attention
|
"Tesla T4" in self.gpu_name
|
||||||
or self.cfg.attn_implementation == "xformers"
|
and self.cfg.attn_implementation == "xformers"
|
||||||
):
|
):
|
||||||
trainer.model.config._attn_implementation = "eager"
|
trainer.model.config._attn_implementation = "eager"
|
||||||
trainer.model.gradient_checkpointing_disable()
|
trainer.model.gradient_checkpointing_disable()
|
||||||
|
|||||||
@@ -28,6 +28,9 @@ from axolotl.utils.schemas.datasets import (
|
|||||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||||
from axolotl.utils.schemas.enums import (
|
from axolotl.utils.schemas.enums import (
|
||||||
|
_NO_DTYPE_CAST_ATTN_IMPLS,
|
||||||
|
_NON_PACKING_ATTN_IMPLS,
|
||||||
|
FLASH_ATTN_LIB_IMPLS,
|
||||||
AttnImplementation,
|
AttnImplementation,
|
||||||
ChatTemplate,
|
ChatTemplate,
|
||||||
RingAttnFunc,
|
RingAttnFunc,
|
||||||
@@ -1332,6 +1335,40 @@ class AxolotlInputConfig(
|
|||||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# --- Attention capability properties ---
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_supports_packing(self) -> bool:
|
||||||
|
"""True if attention supports varlen sample packing via position_ids.
|
||||||
|
|
||||||
|
Known varlen backends: flash, flex, xformers, sage.
|
||||||
|
Unknown strings (e.g., hub kernels like 'kernels-community/flash-attn3')
|
||||||
|
default to True since they generally support varlen.
|
||||||
|
"""
|
||||||
|
if not self.attn_implementation:
|
||||||
|
return False
|
||||||
|
return self.attn_implementation not in _NON_PACKING_ATTN_IMPLS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_uses_flash_lib(self) -> bool:
|
||||||
|
"""True if the backend uses axolotl's flash_attn monkeypatches.
|
||||||
|
|
||||||
|
Only for axolotl-managed FA setup (flash, s2). Hub kernels are
|
||||||
|
HF-managed and don't need these patches.
|
||||||
|
"""
|
||||||
|
return self.attn_implementation in FLASH_ATTN_LIB_IMPLS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_needs_dtype_cast(self) -> bool:
|
||||||
|
"""True if attention needs embedding dtype cast to fp16/bf16.
|
||||||
|
|
||||||
|
Unknown backends (hub kernels) default to True (safe -- harmless
|
||||||
|
if unnecessary, but missing cast causes errors).
|
||||||
|
"""
|
||||||
|
if not self.attn_implementation:
|
||||||
|
return False
|
||||||
|
return self.attn_implementation not in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def warn_peft_trainable_token_to_fix_untrained(cls, data):
|
def warn_peft_trainable_token_to_fix_untrained(cls, data):
|
||||||
@@ -1358,16 +1395,22 @@ class AxolotlInputConfig(
|
|||||||
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
|
"""Normalize attention config: map between attn_implementation enum and legacy boolean flags."""
|
||||||
attn_impl = data.get("attn_implementation")
|
attn_impl = data.get("attn_implementation")
|
||||||
|
|
||||||
# Mapping: attn_implementation value -> (primary flag, extra flags to set)
|
# If gemma4_hybrid_attn_impl is set but no attn_implementation, default
|
||||||
impl_to_flags = {
|
# to flash (the sliding-window layers use FA2, and packing should be enabled).
|
||||||
"eager": (("eager_attention",), ()),
|
if data.get("gemma4_hybrid_attn_impl") and not attn_impl:
|
||||||
"flash": (("flash_attention",), ()),
|
data["attn_implementation"] = "flash"
|
||||||
"sdpa": (("sdp_attention",), ()),
|
attn_impl = "flash"
|
||||||
"xformers": (("xformers_attention",), ("flash_attention",)),
|
|
||||||
"flex": (("flex_attention",), ()),
|
# Mapping: attn_implementation value -> primary legacy flag to set
|
||||||
"sage": (("sage_attention",), ("flash_attention",)),
|
impl_to_flag = {
|
||||||
"s2": (("s2_attention",), ("flash_attention",)),
|
"eager": "eager_attention",
|
||||||
"fp8": ((), ()), # new, no legacy flags
|
"flash": "flash_attention",
|
||||||
|
"sdpa": "sdp_attention",
|
||||||
|
"xformers": "xformers_attention",
|
||||||
|
"flex": "flex_attention",
|
||||||
|
"sage": "sage_attention",
|
||||||
|
"s2": "s2_attention",
|
||||||
|
"fp8": None, # new, no legacy flag
|
||||||
}
|
}
|
||||||
|
|
||||||
# Reverse mapping: legacy flag -> attn_implementation value
|
# Reverse mapping: legacy flag -> attn_implementation value
|
||||||
@@ -1386,26 +1429,21 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
if attn_impl and set_flags:
|
if attn_impl and set_flags:
|
||||||
# Both set — check consistency
|
# Both set — check consistency
|
||||||
if attn_impl in impl_to_flags:
|
expected_flag = impl_to_flag.get(attn_impl)
|
||||||
expected_primary, expected_extra = impl_to_flags[attn_impl]
|
for flag in set_flags:
|
||||||
expected_flags = set(expected_primary) | set(expected_extra)
|
if flag != expected_flag:
|
||||||
for flag in set_flags:
|
raise ValueError(
|
||||||
if flag not in expected_flags:
|
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
|
||||||
raise ValueError(
|
f"Use only attn_implementation or the legacy flag, not both."
|
||||||
f"attn_implementation={attn_impl!r} conflicts with {flag}=true. "
|
)
|
||||||
f"Use only attn_implementation or the legacy flag, not both."
|
|
||||||
)
|
|
||||||
elif attn_impl and not set_flags:
|
elif attn_impl and not set_flags:
|
||||||
# attn_implementation set, no legacy flags — set them for backwards compat
|
# attn_implementation set, no legacy flags — set primary for backwards compat
|
||||||
if attn_impl in impl_to_flags:
|
flag = impl_to_flag.get(attn_impl)
|
||||||
primary, extra = impl_to_flags[attn_impl]
|
if flag:
|
||||||
for flag in (*primary, *extra):
|
data[flag] = True
|
||||||
data[flag] = True
|
|
||||||
elif not attn_impl and set_flags:
|
elif not attn_impl and set_flags:
|
||||||
# Legacy flags set, no attn_implementation — map to enum, warn
|
# Legacy flags set, no attn_implementation — map to enum, warn
|
||||||
# Priority: specific backends first, then generic flash/sdp/eager
|
# Priority: specific backends first, then generic flash/sdp/eager
|
||||||
# s2 and sage require flash_attention internally, so they must be
|
|
||||||
# checked before flash_attention to avoid masking
|
|
||||||
priority = [
|
priority = [
|
||||||
"xformers_attention",
|
"xformers_attention",
|
||||||
"s2_attention",
|
"s2_attention",
|
||||||
@@ -1430,7 +1468,10 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sageattn_wo_sample_packing(cls, data):
|
def check_sageattn_wo_sample_packing(cls, data):
|
||||||
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
|
is_sage = (
|
||||||
|
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
||||||
|
)
|
||||||
|
if (not data.get("sample_packing", False)) and is_sage:
|
||||||
if not data.get("pad_to_sequence_len", False):
|
if not data.get("pad_to_sequence_len", False):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
|
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
|
||||||
@@ -1441,7 +1482,10 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sageattn_fft(cls, data):
|
def check_sageattn_fft(cls, data):
|
||||||
if (not data.get("adapter", False)) and data.get("sage_attention"):
|
is_sage = (
|
||||||
|
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
||||||
|
)
|
||||||
|
if (not data.get("adapter", False)) and is_sage:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"We found loss to drop to 0 with SageAttention full finetuning."
|
"We found loss to drop to 0 with SageAttention full finetuning."
|
||||||
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
|
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
|
||||||
@@ -1531,7 +1575,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
data.get("sample_packing")
|
data.get("sample_packing")
|
||||||
and data.get("sdp_attention")
|
and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
|
||||||
and (data.get("bfloat16") or data.get("bf16"))
|
and (data.get("bfloat16") or data.get("bf16"))
|
||||||
and not is_sm_90
|
and not is_sm_90
|
||||||
):
|
):
|
||||||
@@ -1546,8 +1590,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_compute_capability_w_sageattn(cls, data):
|
def check_compute_capability_w_sageattn(cls, data):
|
||||||
|
is_sage = (
|
||||||
|
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
data.get("sage_attention")
|
is_sage
|
||||||
and data.get("capabilities")
|
and data.get("capabilities")
|
||||||
and data.get("capabilities").get("compute_capability")
|
and data.get("capabilities").get("compute_capability")
|
||||||
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
||||||
@@ -1715,7 +1762,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_flex_torch_version(cls, data):
|
def check_flex_torch_version(cls, data):
|
||||||
if (data.get("flex_attention") is not None) and (data.get("flex_attention")):
|
if data.get("flex_attention") or data.get("attn_implementation") == "flex":
|
||||||
env_capabilities = data.get("env_capabilities", {})
|
env_capabilities = data.get("env_capabilities", {})
|
||||||
torch_version = env_capabilities.get("torch_version")
|
torch_version = env_capabilities.get("torch_version")
|
||||||
|
|
||||||
|
|||||||
@@ -110,6 +110,19 @@ class AttnImplementation(str, Enum):
|
|||||||
fp8 = "fp8" # pylint: disable=invalid-name
|
fp8 = "fp8" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# Backends that require the flash_attn library (Dao-AILab/flash-attention)
|
||||||
|
# for axolotl's own monkeypatches (FA4 auto-apply, LLaMA flash hijack, etc.)
|
||||||
|
FLASH_ATTN_LIB_IMPLS = frozenset({"flash", "s2"})
|
||||||
|
|
||||||
|
# Known backends that do NOT support varlen sample packing via position_ids.
|
||||||
|
# Used as an exclusion list: unknown strings (e.g., HF hub kernels like
|
||||||
|
# "kernels-community/flash-attn3") default to packing-capable.
|
||||||
|
_NON_PACKING_ATTN_IMPLS = frozenset({"eager", "sdpa", "s2", "fp8"})
|
||||||
|
|
||||||
|
# Known backends that do NOT need embedding dtype cast.
|
||||||
|
_NO_DTYPE_CAST_ATTN_IMPLS = frozenset({"eager", "sdpa"})
|
||||||
|
|
||||||
|
|
||||||
class RingAttnFunc(str, Enum):
|
class RingAttnFunc(str, Enum):
|
||||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,12 @@ from pydantic import (
|
|||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
from axolotl.utils.schemas.enums import (
|
||||||
|
_NON_PACKING_ATTN_IMPLS,
|
||||||
|
ChatTemplate,
|
||||||
|
RingAttnFunc,
|
||||||
|
RLType,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -182,6 +187,10 @@ class AttentionValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_attention_fields(cls, data):
|
def check_attention_fields(cls, data):
|
||||||
|
# If attn_implementation is set, the enum handles mutual exclusivity.
|
||||||
|
# This validator catches legacy configs with multiple boolean flags.
|
||||||
|
if data.get("attn_implementation"):
|
||||||
|
return data
|
||||||
fields = (
|
fields = (
|
||||||
"xformers_attention",
|
"xformers_attention",
|
||||||
"sdp_attention",
|
"sdp_attention",
|
||||||
@@ -436,7 +445,7 @@ class TrainingValidationMixin:
|
|||||||
not (self.bf16 or self.bfloat16)
|
not (self.bf16 or self.bfloat16)
|
||||||
and (self.fp16 or self.float16)
|
and (self.fp16 or self.float16)
|
||||||
and not self.adapter
|
and not self.adapter
|
||||||
and not self.flash_attention
|
and not self.attn_uses_flash_lib
|
||||||
and self.sample_packing
|
and self.sample_packing
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -946,8 +955,16 @@ class OptimizationValidationMixin:
|
|||||||
def check_batch_flattening_fa(cls, data):
|
def check_batch_flattening_fa(cls, data):
|
||||||
if data.get("batch_flattening"):
|
if data.get("batch_flattening"):
|
||||||
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
||||||
if not data.get("flash_attention") and not batch_flattening_auto:
|
has_varlen_attn = (
|
||||||
raise ValueError("batch_flattening requires flash attention")
|
data.get("attn_implementation") not in _NON_PACKING_ATTN_IMPLS
|
||||||
|
if data.get("attn_implementation")
|
||||||
|
else data.get("flash_attention")
|
||||||
|
)
|
||||||
|
if not has_varlen_attn and not batch_flattening_auto:
|
||||||
|
raise ValueError(
|
||||||
|
"batch_flattening requires a varlen-capable attention backend "
|
||||||
|
"(e.g., attn_implementation: flash)"
|
||||||
|
)
|
||||||
if data.get("sample_packing") and not batch_flattening_auto:
|
if data.get("sample_packing") and not batch_flattening_auto:
|
||||||
raise ValueError("batch_flattening not compatible with sample_packing")
|
raise ValueError("batch_flattening not compatible with sample_packing")
|
||||||
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
||||||
@@ -966,7 +983,7 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
batch_flattening_auto
|
batch_flattening_auto
|
||||||
and data.get("flash_attention")
|
and has_varlen_attn
|
||||||
and not data.get("sample_packing")
|
and not data.get("sample_packing")
|
||||||
and data.get("micro_batch_size") > 1
|
and data.get("micro_batch_size") > 1
|
||||||
):
|
):
|
||||||
@@ -1211,6 +1228,12 @@ class SystemValidationMixin:
|
|||||||
def check_npu_config(cls, data):
|
def check_npu_config(cls, data):
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
# check attention config
|
# check attention config
|
||||||
|
unsupported_npu_impls = {"flash", "sdpa", "s2"}
|
||||||
|
attn_impl = data.get("attn_implementation")
|
||||||
|
if attn_impl and attn_impl in unsupported_npu_impls:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
|
||||||
|
)
|
||||||
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
||||||
for attn in attn_list:
|
for attn in attn_list:
|
||||||
if data.get(attn):
|
if data.get(attn):
|
||||||
@@ -1519,9 +1542,10 @@ class ComplexValidationMixin:
|
|||||||
if not self.context_parallel_size:
|
if not self.context_parallel_size:
|
||||||
self.context_parallel_size = 1
|
self.context_parallel_size = 1
|
||||||
elif self.context_parallel_size > 1:
|
elif self.context_parallel_size > 1:
|
||||||
if not self.flash_attention:
|
if not self.attn_uses_flash_lib:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with context_parallel_size > 1"
|
"context_parallel_size > 1 requires flash attention "
|
||||||
|
"(attn_implementation: flash or s2)."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sample_packing and self.micro_batch_size > 1:
|
if self.sample_packing and self.micro_batch_size > 1:
|
||||||
@@ -1658,7 +1682,9 @@ class EBFTValidationMixin:
|
|||||||
if (
|
if (
|
||||||
data.get("rl") == "ebft"
|
data.get("rl") == "ebft"
|
||||||
and data.get("ebft", {}).get("mode") == "strided"
|
and data.get("ebft", {}).get("mode") == "strided"
|
||||||
and data.get("flex_attention")
|
and (
|
||||||
|
data.get("flex_attention") or data.get("attn_implementation") == "flex"
|
||||||
|
)
|
||||||
and data.get("gradient_checkpointing")
|
and data.get("gradient_checkpointing")
|
||||||
):
|
):
|
||||||
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
|
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
|
||||||
|
|||||||
@@ -462,7 +462,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if cfg.flash_attention and not cfg.multipack_real_batches:
|
if cfg.attn_supports_packing and not cfg.multipack_real_batches:
|
||||||
sampler_batch_size = 1
|
sampler_batch_size = 1
|
||||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,11 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
Tests for attn_implementation normalization, registry registration, and
|
Tests for attn_implementation normalization, registry registration,
|
||||||
backwards compatibility with legacy boolean attention flags.
|
capability properties, and backwards compatibility with legacy boolean
|
||||||
|
attention flags.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
from axolotl.utils.schemas.enums import (
|
||||||
|
_NO_DTYPE_CAST_ATTN_IMPLS,
|
||||||
|
_NON_PACKING_ATTN_IMPLS,
|
||||||
|
FLASH_ATTN_LIB_IMPLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAttnImplementationNormalizer:
|
class TestAttnImplementationNormalizer:
|
||||||
@@ -18,22 +24,31 @@ class TestAttnImplementationNormalizer:
|
|||||||
# --- Forward mapping: attn_implementation -> legacy flags ---
|
# --- Forward mapping: attn_implementation -> legacy flags ---
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"impl,expected_flags",
|
"impl,expected_flag",
|
||||||
[
|
[
|
||||||
("eager", {"eager_attention": True}),
|
("eager", "eager_attention"),
|
||||||
("flash", {"flash_attention": True}),
|
("flash", "flash_attention"),
|
||||||
("sdpa", {"sdp_attention": True}),
|
("sdpa", "sdp_attention"),
|
||||||
("flex", {"flex_attention": True}),
|
("flex", "flex_attention"),
|
||||||
("xformers", {"xformers_attention": True, "flash_attention": True}),
|
("xformers", "xformers_attention"),
|
||||||
("sage", {"sage_attention": True, "flash_attention": True}),
|
("sage", "sage_attention"),
|
||||||
("s2", {"s2_attention": True, "flash_attention": True}),
|
("s2", "s2_attention"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_attn_impl_sets_legacy_flags(self, impl, expected_flags):
|
def test_attn_impl_sets_primary_legacy_flag(self, impl, expected_flag):
|
||||||
data = {"attn_implementation": impl}
|
data = {"attn_implementation": impl}
|
||||||
result = AxolotlInputConfig.normalize_attn_implementation(data)
|
result = AxolotlInputConfig.normalize_attn_implementation(data)
|
||||||
for flag, val in expected_flags.items():
|
assert result.get(expected_flag) is True, (
|
||||||
assert result.get(flag) == val, f"{impl}: expected {flag}={val}"
|
f"{impl}: expected {expected_flag}=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("impl", ["xformers", "sage", "s2"])
|
||||||
|
def test_attn_impl_does_not_set_flash_for_non_flash(self, impl):
|
||||||
|
"""xformers, sage, s2 should NOT set flash_attention=True anymore."""
|
||||||
|
result = self._normalize({"attn_implementation": impl})
|
||||||
|
assert not result.get("flash_attention"), (
|
||||||
|
f"{impl} should not set flash_attention"
|
||||||
|
)
|
||||||
|
|
||||||
def test_fp8_sets_no_legacy_flags(self):
|
def test_fp8_sets_no_legacy_flags(self):
|
||||||
result = self._normalize({"attn_implementation": "fp8"})
|
result = self._normalize({"attn_implementation": "fp8"})
|
||||||
@@ -87,27 +102,13 @@ class TestAttnImplementationNormalizer:
|
|||||||
assert result["attn_implementation"] == "flash"
|
assert result["attn_implementation"] == "flash"
|
||||||
assert result["flash_attention"] is True
|
assert result["flash_attention"] is True
|
||||||
|
|
||||||
def test_consistent_xformers_with_extra_flags(self):
|
def test_consistent_xformers_with_own_flag(self):
|
||||||
"""xformers needs flash_attention=True, so both flags with attn_impl should be OK."""
|
"""xformers + xformers_attention should be OK."""
|
||||||
result = self._normalize(
|
result = self._normalize(
|
||||||
{
|
{"attn_implementation": "xformers", "xformers_attention": True}
|
||||||
"attn_implementation": "xformers",
|
|
||||||
"xformers_attention": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
assert result["attn_implementation"] == "xformers"
|
assert result["attn_implementation"] == "xformers"
|
||||||
|
|
||||||
def test_consistent_s2_with_flash(self):
|
|
||||||
result = self._normalize(
|
|
||||||
{
|
|
||||||
"attn_implementation": "s2",
|
|
||||||
"s2_attention": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert result["attn_implementation"] == "s2"
|
|
||||||
|
|
||||||
# --- Conflict detection ---
|
# --- Conflict detection ---
|
||||||
|
|
||||||
def test_conflicting_impl_and_flag_raises(self):
|
def test_conflicting_impl_and_flag_raises(self):
|
||||||
@@ -118,6 +119,28 @@ class TestAttnImplementationNormalizer:
|
|||||||
with pytest.raises(ValueError, match="conflicts with"):
|
with pytest.raises(ValueError, match="conflicts with"):
|
||||||
self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
|
self._normalize({"attn_implementation": "xformers", "sdp_attention": True})
|
||||||
|
|
||||||
|
def test_xformers_with_flash_flag_conflicts(self):
|
||||||
|
"""After normalizer change, xformers no longer expects flash_attention."""
|
||||||
|
with pytest.raises(ValueError, match="conflicts with"):
|
||||||
|
self._normalize(
|
||||||
|
{
|
||||||
|
"attn_implementation": "xformers",
|
||||||
|
"xformers_attention": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_s2_with_flash_flag_conflicts(self):
|
||||||
|
"""After normalizer change, s2 no longer expects flash_attention."""
|
||||||
|
with pytest.raises(ValueError, match="conflicts with"):
|
||||||
|
self._normalize(
|
||||||
|
{
|
||||||
|
"attn_implementation": "s2",
|
||||||
|
"s2_attention": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# --- Hub kernel strings pass through ---
|
# --- Hub kernel strings pass through ---
|
||||||
|
|
||||||
def test_hub_kernel_passthrough(self):
|
def test_hub_kernel_passthrough(self):
|
||||||
@@ -144,16 +167,69 @@ class TestAttnImplementationNormalizer:
|
|||||||
result = self._normalize({"some_other_config": True})
|
result = self._normalize({"some_other_config": True})
|
||||||
assert result.get("attn_implementation") is None
|
assert result.get("attn_implementation") is None
|
||||||
|
|
||||||
# --- Sample packing interactions ---
|
# --- Gemma4 hybrid ---
|
||||||
|
|
||||||
def test_xformers_with_sample_packing_sets_flash(self):
|
def test_gemma4_hybrid_sets_flash(self):
|
||||||
"""xformers + sample_packing needs flash_attention=True for the patch chain."""
|
"""gemma4_hybrid_attn_impl should default attn_implementation to flash."""
|
||||||
result = self._normalize(
|
result = self._normalize({"gemma4_hybrid_attn_impl": True})
|
||||||
{"attn_implementation": "xformers", "sample_packing": True}
|
assert result["attn_implementation"] == "flash"
|
||||||
)
|
|
||||||
assert result["xformers_attention"] is True
|
|
||||||
assert result["flash_attention"] is True
|
assert result["flash_attention"] is True
|
||||||
|
|
||||||
|
def test_gemma4_hybrid_does_not_override_explicit(self):
|
||||||
|
"""If attn_implementation is already set, gemma4 should not override it."""
|
||||||
|
result = self._normalize(
|
||||||
|
{"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"}
|
||||||
|
)
|
||||||
|
assert result["attn_implementation"] == "sdpa"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAttnCapabilityProperties:
|
||||||
|
"""Test the capability properties on the normalizer data.
|
||||||
|
|
||||||
|
Since these are @property on AxolotlInputConfig (a Pydantic model),
|
||||||
|
we test the underlying logic directly using the constant sets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --- attn_supports_packing ---
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("impl", ["flash", "flex", "xformers", "sage"])
|
||||||
|
def test_supports_packing_true(self, impl):
|
||||||
|
assert impl not in _NON_PACKING_ATTN_IMPLS
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"])
|
||||||
|
def test_supports_packing_false(self, impl):
|
||||||
|
assert impl in _NON_PACKING_ATTN_IMPLS
|
||||||
|
|
||||||
|
def test_hub_kernel_supports_packing(self):
|
||||||
|
"""Unknown hub kernels should default to packing-capable."""
|
||||||
|
assert "kernels-community/flash-attn3" not in _NON_PACKING_ATTN_IMPLS
|
||||||
|
|
||||||
|
# --- attn_uses_flash_lib ---
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("impl", ["flash", "s2"])
|
||||||
|
def test_uses_flash_lib_true(self, impl):
|
||||||
|
assert impl in FLASH_ATTN_LIB_IMPLS
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"impl", ["eager", "sdpa", "xformers", "flex", "sage", "fp8"]
|
||||||
|
)
|
||||||
|
def test_uses_flash_lib_false(self, impl):
|
||||||
|
assert impl not in FLASH_ATTN_LIB_IMPLS
|
||||||
|
|
||||||
|
def test_hub_kernel_not_flash_lib(self):
|
||||||
|
"""Hub kernels are HF-managed, not axolotl monkeypatch targets."""
|
||||||
|
assert "kernels-community/flash-attn3" not in FLASH_ATTN_LIB_IMPLS
|
||||||
|
|
||||||
|
# --- attn_needs_dtype_cast ---
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("impl", ["eager", "sdpa"])
|
||||||
|
def test_no_dtype_cast(self, impl):
|
||||||
|
assert impl in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("impl", ["flash", "flex", "sage", "xformers", "s2", "fp8"])
|
||||||
|
def test_needs_dtype_cast(self, impl):
|
||||||
|
assert impl not in _NO_DTYPE_CAST_ATTN_IMPLS
|
||||||
|
|
||||||
|
|
||||||
class TestAttnImplToHFMapping:
|
class TestAttnImplToHFMapping:
|
||||||
"""Test that attn_implementation enum values map correctly to HF strings."""
|
"""Test that attn_implementation enum values map correctly to HF strings."""
|
||||||
|
|||||||
Reference in New Issue
Block a user