move attention-dependent validators to mode=after

This commit is contained in:
Wing Lian
2026-04-23 21:23:11 +00:00
parent 2579c496d5
commit bce65e3332
2 changed files with 164 additions and 159 deletions

View File

@@ -1474,32 +1474,27 @@ class AxolotlInputConfig(
f"path containing '/'." f"path containing '/'."
) )
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_sageattn_wo_sample_packing(self):
def check_sageattn_wo_sample_packing(cls, data): if (
is_sage = ( self.attn_implementation == "sage"
data.get("sage_attention") or data.get("attn_implementation") == "sage" and not self.sample_packing
) and not self.pad_to_sequence_len
if (not data.get("sample_packing", False)) and is_sage: ):
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
"This is because there has been signs that the loss explodes after a few steps."
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("adapter", False)) and is_sage:
LOG.warning( LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning." "We recommend turning on `pad_to_sequence_len` for SageAttention "
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method." "without packing. The loss has been observed to explode otherwise."
) )
return data return self
@model_validator(mode="after")
def check_sageattn_fft(self):
if self.attn_implementation == "sage" and not self.adapter:
LOG.warning(
"SageAttention full finetuning has been observed to drop loss to 0. "
"Monitor the loss, or switch to LoRA/QLoRA or another attention method."
)
return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -1575,17 +1570,13 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
) )
return self return self
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_sample_packing_w_sdpa_bf16(self):
def check_sample_packing_w_sdpa_bf16(cls, data): is_sm_90 = self.capabilities and self.capabilities.compute_capability == "sm_90"
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if ( if (
data.get("sample_packing") self.sample_packing
and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa") and self.attn_implementation == "sdpa"
and (data.get("bfloat16") or data.get("bf16")) and (self.bfloat16 or self.bf16)
and not is_sm_90 and not is_sm_90
): ):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
@@ -1593,26 +1584,51 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. " "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
"This may work on H100s." "This may work on H100s."
) )
return self
return data @model_validator(mode="after")
def check_compute_capability_w_sageattn(self):
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if ( if (
is_sage self.attn_implementation == "sage"
and data.get("capabilities") and self.capabilities
and data.get("capabilities").get("compute_capability") and self.capabilities.compute_capability
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"] not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
): ):
raise ValueError( raise ValueError(
"SageAttention supports compute capability between sm_80 and sm_120. " "SageAttention supports compute capability between sm_80 and sm_120. "
"Please use a different attention implementation." "Please use a different attention implementation."
) )
return data return self
@model_validator(mode="after")
def check_fp8_attention_preflight(self):
"""fp8 attention requires SM90+ and torch >= 2.11 (torchao >= 0.17 is pinned)."""
if self.attn_implementation != "fp8":
return self
if self.capabilities and self.capabilities.compute_capability:
cc = self.capabilities.compute_capability
# Accept sm_90 (H100/H200), sm_100 (B100/B200), sm_120 (B300-class).
if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90:
raise ValueError(
f"attn_implementation=fp8 requires compute capability sm_90 or "
f"higher (Hopper+). Detected {cc!r}."
)
torch_version = (
self.env_capabilities.torch_version if self.env_capabilities else None
)
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.11.0"):
raise ValueError(
f"attn_implementation=fp8 requires PyTorch >= 2.11.0. "
f"Detected {torch_version}."
)
return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -1768,16 +1784,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
) )
return data return data
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_flex_torch_version(self):
def check_flex_torch_version(cls, data): if self.attn_implementation == "flex_attention":
if ( torch_version = (
data.get("flex_attention") self.env_capabilities.torch_version if self.env_capabilities else None
or data.get("attn_implementation") == "flex_attention" )
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None: if torch_version is None:
import torch import torch
@@ -1787,7 +1799,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
raise ValueError( raise ValueError(
"Flex attention is not supported on torch version < 2.6.0" "Flex attention is not supported on torch version < 2.6.0"
) )
return data return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View File

@@ -13,7 +13,6 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import ( from axolotl.utils.schemas.enums import (
ATTN_IMPLS_SUPPORTING_PACKING,
ChatTemplate, ChatTemplate,
RingAttnFunc, RingAttnFunc,
RLType, RLType,
@@ -187,47 +186,42 @@ class AttentionValidationMixin:
# `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation` # `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
# is now the single entry point for attention-input mapping and conflict detection. # is now the single entry point for attention-input mapping and conflict detection.
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_sample_packing_without_attention(self):
def check_sample_packing_without_attention(cls, data): if self.sample_packing and not self.attn_supports_packing:
if ( if self.attn_implementation:
data.get("sample_packing") LOG.warning(
and not data.get("attn_implementation") "`sample_packing` with `attn_implementation=%r` does not handle "
and not data.get("flash_attention") "cross-sample decontamination. Use a varlen-capable backend "
and not data.get("sdp_attention") "(e.g. flash_attention_2, flex_attention, xformers, sage) to "
and not data.get("flex_attention") "isolate samples.",
and not data.get("xformers_attention") self.attn_implementation,
and not data.get("sage_attention") )
): else:
LOG.warning( LOG.warning(
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination." "`sample_packing` without an attention backend does not handle "
) "cross-sample decontamination. Set `attn_implementation` to a "
return data "varlen-capable backend (e.g. flash_attention_2)."
)
return self
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_sample_packing_with_s2attn(self):
def check_sample_packing_with_s2attn(cls, data): if self.sample_packing and self.attn_implementation == "s2":
if data.get("sample_packing") and (
data.get("s2_attention") or data.get("attn_implementation") == "s2"
):
raise ValueError( raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \ "Received `sample_packing=true` and `attn_implementation=s2`; "
shifted-sparse attention does not currently support sample packing." "shifted-sparse attention does not currently support sample packing."
) )
return data return self
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_scaling_softmax_requires_flex(self):
def check_scaling_softmax_requires_flex(cls, data): if self.scaling_softmax and self.attn_implementation != "flex_attention":
if data.get("scaling_softmax") and not (
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
):
raise ValueError( raise ValueError(
"scaling_softmax requires flex attention.\n" "scaling_softmax requires flex attention. "
"Add 'attn_implementation: flex' to your config file.\n" "Add `attn_implementation: flex_attention` to your config."
) )
return data return self
class TrainingValidationMixin: class TrainingValidationMixin:
@@ -933,48 +927,45 @@ class OptimizationValidationMixin:
) )
return data return data
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_batch_flattening_fa(self):
def check_batch_flattening_fa(cls, data): if not self.batch_flattening:
if data.get("batch_flattening"): return self
batch_flattening_auto = data.get("batch_flattening") == "auto"
has_varlen_attn = ( batch_flattening_auto = self.batch_flattening == "auto"
data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING has_varlen_attn = self.attn_supports_packing
if data.get("attn_implementation")
else data.get("flash_attention") if not has_varlen_attn and not batch_flattening_auto:
raise ValueError(
"batch_flattening requires a varlen-capable attention backend "
"(e.g., attn_implementation: flash_attention_2)."
) )
if not has_varlen_attn and not batch_flattening_auto: if self.sample_packing and not batch_flattening_auto:
raise ValueError( raise ValueError("batch_flattening not compatible with sample_packing")
"batch_flattening requires a varlen-capable attention backend " if self.micro_batch_size == 1 and not batch_flattening_auto:
"(e.g., attn_implementation: flash)" LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
# Liger loss takes a separate code path (compute_liger_loss) that # Liger loss takes a separate code path (compute_liger_loss) that
# bypasses the flattened training forward pass. Batch flattening # bypasses the flattened training forward pass. Batch flattening
# still applies to the scoring/deferred logprobs path. # still applies to the scoring/deferred logprobs path.
trl_cfg = data.get("trl") or {} if self.trl and getattr(self.trl, "use_liger_loss", False):
if isinstance(trl_cfg, dict) and trl_cfg.get("use_liger_loss"): LOG.warning(
LOG.warning( "batch_flattening with use_liger_loss: flattening will only "
"batch_flattening with use_liger_loss: flattening will only " "apply to the scoring path (deferred logprobs). The training "
"apply to the scoring path (deferred logprobs). The training " "forward pass uses Liger's fused lm_head+loss kernel instead."
"forward pass uses Liger's fused lm_head+loss kernel instead." )
)
if ( if (
batch_flattening_auto batch_flattening_auto
and has_varlen_attn and has_varlen_attn
and not data.get("sample_packing") and not self.sample_packing
and data.get("micro_batch_size") > 1 and self.micro_batch_size > 1
): ):
data["batch_flattening"] = True self.batch_flattening = True
elif batch_flattening_auto: elif batch_flattening_auto:
data["batch_flattening"] = False self.batch_flattening = False
return data return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -1211,12 +1202,18 @@ class SystemValidationMixin:
def check_npu_config(cls, data): def check_npu_config(cls, data):
if is_torch_npu_available(): if is_torch_npu_available():
# check attention config # check attention config
unsupported_npu_impls = {"flash", "sdpa", "s2"} unsupported_npu_impls = {
"flash_attention_2",
"flash_attention_3",
"sdpa",
"s2",
}
attn_impl = data.get("attn_implementation") attn_impl = data.get("attn_implementation")
if attn_impl and attn_impl in unsupported_npu_impls: if attn_impl and attn_impl in unsupported_npu_impls:
raise NotImplementedError( raise NotImplementedError(
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU." f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
) )
# Legacy flags still present at this point (normalizer strips them later).
attn_list = ["flash_attention", "sdp_attention", "s2_attention"] attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list: for attn in attn_list:
if data.get(attn): if data.get(attn):
@@ -1658,50 +1655,46 @@ class EBFTValidationMixin:
) )
return data return data
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_ebft_gradient_checkpointing_reentrant(self):
def check_ebft_gradient_checkpointing_reentrant(cls, data):
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError.""" """flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
if ( if (
data.get("rl") == "ebft" self.rl == "ebft"
and data.get("ebft", {}).get("mode") == "strided" and (self.ebft or {}).get("mode") == "strided"
and ( and self.attn_implementation == "flex_attention"
data.get("flex_attention") and self.gradient_checkpointing
or data.get("attn_implementation") == "flex_attention"
)
and data.get("gradient_checkpointing")
): ):
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {} gc_kwargs = self.gradient_checkpointing_kwargs or {}
if not gc_kwargs.get("use_reentrant"): if not gc_kwargs.get("use_reentrant"):
LOG.warning( LOG.warning(
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in " "EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
"gradient_checkpointing_kwargs (required for flex_attention compatibility). " "gradient_checkpointing_kwargs (required for flex_attention compatibility). "
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata." "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
) )
if data.get("gradient_checkpointing_kwargs") is None: if self.gradient_checkpointing_kwargs is None:
data["gradient_checkpointing_kwargs"] = {} self.gradient_checkpointing_kwargs = {}
data["gradient_checkpointing_kwargs"]["use_reentrant"] = True self.gradient_checkpointing_kwargs["use_reentrant"] = True
return data return self
@model_validator(mode="before") @model_validator(mode="after")
@classmethod def check_ebft_activation_offloading(self):
def check_ebft_activation_offloading(cls, data):
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping, """activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
which conflicts with flex_attention's use_reentrant requirement.""" which conflicts with flex_attention's use_reentrant requirement."""
if ( if (
data.get("rl") == "ebft" self.rl == "ebft"
and data.get("ebft", {}).get("mode") == "strided" and (self.ebft or {}).get("mode") == "strided"
and data.get("activation_offloading") is True and self.activation_offloading is True
and data.get("flex_attention") and self.attn_implementation == "flex_attention"
): ):
raise ValueError( raise ValueError(
"EBFT strided mode: `activation_offloading: true` is incompatible with " "EBFT strided mode: `activation_offloading: true` is incompatible with "
"`flex_attention: true`. Activation offloading replaces gradient checkpointing " "`attn_implementation: flex_attention`. Activation offloading replaces "
"with FSDP-style wrapping that conflicts with flex_attention's reentrant " "gradient checkpointing with FSDP-style wrapping that conflicts with "
"checkpoint requirement. Remove `activation_offloading` — the strided trainer " "flex_attention's reentrant checkpoint requirement. Remove "
"uses micro-batched forward passes for memory efficiency instead." "`activation_offloading` — the strided trainer uses micro-batched forward "
"passes for memory efficiency instead."
) )
return data return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod