move attention-dependent validators to mode=after

This commit is contained in:
Wing Lian
2026-04-23 21:23:11 +00:00
parent 2579c496d5
commit bce65e3332
2 changed files with 164 additions and 159 deletions

View File

@@ -1474,32 +1474,27 @@ class AxolotlInputConfig(
f"path containing '/'."
)
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("sample_packing", False)) and is_sage:
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
"This is because there has been signs that the loss explodes after a few steps."
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
if (not data.get("adapter", False)) and is_sage:
@model_validator(mode="after")
def check_sageattn_wo_sample_packing(self):
if (
self.attn_implementation == "sage"
and not self.sample_packing
and not self.pad_to_sequence_len
):
LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
"We recommend turning on `pad_to_sequence_len` for SageAttention "
"without packing. The loss has been observed to explode otherwise."
)
return data
return self
@model_validator(mode="after")
def check_sageattn_fft(self):
if self.attn_implementation == "sage" and not self.adapter:
LOG.warning(
"SageAttention full finetuning has been observed to drop loss to 0. "
"Monitor the loss, or switch to LoRA/QLoRA or another attention method."
)
return self
@model_validator(mode="before")
@classmethod
@@ -1575,17 +1570,13 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
@model_validator(mode="after")
def check_sample_packing_w_sdpa_bf16(self):
is_sm_90 = self.capabilities and self.capabilities.compute_capability == "sm_90"
if (
data.get("sample_packing")
and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
and (data.get("bfloat16") or data.get("bf16"))
self.sample_packing
and self.attn_implementation == "sdpa"
and (self.bfloat16 or self.bf16)
and not is_sm_90
):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
@@ -1593,26 +1584,51 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
"This may work on H100s."
)
return self
return data
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
is_sage = (
data.get("sage_attention") or data.get("attn_implementation") == "sage"
)
@model_validator(mode="after")
def check_compute_capability_w_sageattn(self):
if (
is_sage
and data.get("capabilities")
and data.get("capabilities").get("compute_capability")
self.attn_implementation == "sage"
and self.capabilities
and self.capabilities.compute_capability
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
):
raise ValueError(
"SageAttention supports compute capability between sm_80 and sm_120. "
"Please use a different attention implementation."
)
return data
return self
@model_validator(mode="after")
def check_fp8_attention_preflight(self):
"""fp8 attention requires SM90+ and torch >= 2.11 (torchao >= 0.17 is pinned)."""
if self.attn_implementation != "fp8":
return self
if self.capabilities and self.capabilities.compute_capability:
cc = self.capabilities.compute_capability
# Accept sm_90 (H100/H200), sm_100 (B100/B200), sm_120 (B300-class).
if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90:
raise ValueError(
f"attn_implementation=fp8 requires compute capability sm_90 or "
f"higher (Hopper+). Detected {cc!r}."
)
torch_version = (
self.env_capabilities.torch_version if self.env_capabilities else None
)
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.11.0"):
raise ValueError(
f"attn_implementation=fp8 requires PyTorch >= 2.11.0. "
f"Detected {torch_version}."
)
return self
@model_validator(mode="before")
@classmethod
@@ -1768,16 +1784,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
@model_validator(mode="after")
def check_flex_torch_version(self):
if self.attn_implementation == "flex_attention":
torch_version = (
self.env_capabilities.torch_version if self.env_capabilities else None
)
if torch_version is None:
import torch
@@ -1787,7 +1799,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
raise ValueError(
"Flex attention is not supported on torch version < 2.6.0"
)
return data
return self
@model_validator(mode="before")
@classmethod

View File

@@ -13,7 +13,6 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import (
ATTN_IMPLS_SUPPORTING_PACKING,
ChatTemplate,
RingAttnFunc,
RLType,
@@ -187,47 +186,42 @@ class AttentionValidationMixin:
# `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
# is now the single entry point for attention-input mapping and conflict detection.
@model_validator(mode="before")
@classmethod
def check_sample_packing_without_attention(cls, data):
if (
data.get("sample_packing")
and not data.get("attn_implementation")
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
and not data.get("xformers_attention")
and not data.get("sage_attention")
):
LOG.warning(
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
)
return data
@model_validator(mode="after")
def check_sample_packing_without_attention(self):
if self.sample_packing and not self.attn_supports_packing:
if self.attn_implementation:
LOG.warning(
"`sample_packing` with `attn_implementation=%r` does not handle "
"cross-sample decontamination. Use a varlen-capable backend "
"(e.g. flash_attention_2, flex_attention, xformers, sage) to "
"isolate samples.",
self.attn_implementation,
)
else:
LOG.warning(
"`sample_packing` without an attention backend does not handle "
"cross-sample decontamination. Set `attn_implementation` to a "
"varlen-capable backend (e.g. flash_attention_2)."
)
return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_with_s2attn(cls, data):
if data.get("sample_packing") and (
data.get("s2_attention") or data.get("attn_implementation") == "s2"
):
@model_validator(mode="after")
def check_sample_packing_with_s2attn(self):
if self.sample_packing and self.attn_implementation == "s2":
raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \
shifted-sparse attention does not currently support sample packing."
"Received `sample_packing=true` and `attn_implementation=s2`; "
"shifted-sparse attention does not currently support sample packing."
)
return data
return self
@model_validator(mode="before")
@classmethod
def check_scaling_softmax_requires_flex(cls, data):
if data.get("scaling_softmax") and not (
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
):
@model_validator(mode="after")
def check_scaling_softmax_requires_flex(self):
if self.scaling_softmax and self.attn_implementation != "flex_attention":
raise ValueError(
"scaling_softmax requires flex attention.\n"
"Add 'attn_implementation: flex' to your config file.\n"
"scaling_softmax requires flex attention. "
"Add `attn_implementation: flex_attention` to your config."
)
return data
return self
class TrainingValidationMixin:
@@ -933,48 +927,45 @@ class OptimizationValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
has_varlen_attn = (
data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING
if data.get("attn_implementation")
else data.get("flash_attention")
@model_validator(mode="after")
def check_batch_flattening_fa(self):
if not self.batch_flattening:
return self
batch_flattening_auto = self.batch_flattening == "auto"
has_varlen_attn = self.attn_supports_packing
if not has_varlen_attn and not batch_flattening_auto:
raise ValueError(
"batch_flattening requires a varlen-capable attention backend "
"(e.g., attn_implementation: flash_attention_2)."
)
if not has_varlen_attn and not batch_flattening_auto:
raise ValueError(
"batch_flattening requires a varlen-capable attention backend "
"(e.g., attn_implementation: flash)"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
if self.sample_packing and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if self.micro_batch_size == 1 and not batch_flattening_auto:
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
# Liger loss takes a separate code path (compute_liger_loss) that
# bypasses the flattened training forward pass. Batch flattening
# still applies to the scoring/deferred logprobs path.
trl_cfg = data.get("trl") or {}
if isinstance(trl_cfg, dict) and trl_cfg.get("use_liger_loss"):
LOG.warning(
"batch_flattening with use_liger_loss: flattening will only "
"apply to the scoring path (deferred logprobs). The training "
"forward pass uses Liger's fused lm_head+loss kernel instead."
)
# Liger loss takes a separate code path (compute_liger_loss) that
# bypasses the flattened training forward pass. Batch flattening
# still applies to the scoring/deferred logprobs path.
if self.trl and getattr(self.trl, "use_liger_loss", False):
LOG.warning(
"batch_flattening with use_liger_loss: flattening will only "
"apply to the scoring path (deferred logprobs). The training "
"forward pass uses Liger's fused lm_head+loss kernel instead."
)
if (
batch_flattening_auto
and has_varlen_attn
and not data.get("sample_packing")
and data.get("micro_batch_size") > 1
):
data["batch_flattening"] = True
elif batch_flattening_auto:
data["batch_flattening"] = False
if (
batch_flattening_auto
and has_varlen_attn
and not self.sample_packing
and self.micro_batch_size > 1
):
self.batch_flattening = True
elif batch_flattening_auto:
self.batch_flattening = False
return data
return self
@model_validator(mode="before")
@classmethod
@@ -1211,12 +1202,18 @@ class SystemValidationMixin:
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
unsupported_npu_impls = {"flash", "sdpa", "s2"}
unsupported_npu_impls = {
"flash_attention_2",
"flash_attention_3",
"sdpa",
"s2",
}
attn_impl = data.get("attn_implementation")
if attn_impl and attn_impl in unsupported_npu_impls:
raise NotImplementedError(
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
)
# Legacy flags still present at this point (normalizer strips them later).
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
@@ -1658,50 +1655,46 @@ class EBFTValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_ebft_gradient_checkpointing_reentrant(cls, data):
@model_validator(mode="after")
def check_ebft_gradient_checkpointing_reentrant(self):
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and (
data.get("flex_attention")
or data.get("attn_implementation") == "flex_attention"
)
and data.get("gradient_checkpointing")
self.rl == "ebft"
and (self.ebft or {}).get("mode") == "strided"
and self.attn_implementation == "flex_attention"
and self.gradient_checkpointing
):
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
gc_kwargs = self.gradient_checkpointing_kwargs or {}
if not gc_kwargs.get("use_reentrant"):
LOG.warning(
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
"gradient_checkpointing_kwargs (required for flex_attention compatibility). "
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
)
if data.get("gradient_checkpointing_kwargs") is None:
data["gradient_checkpointing_kwargs"] = {}
data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
return data
if self.gradient_checkpointing_kwargs is None:
self.gradient_checkpointing_kwargs = {}
self.gradient_checkpointing_kwargs["use_reentrant"] = True
return self
@model_validator(mode="before")
@classmethod
def check_ebft_activation_offloading(cls, data):
@model_validator(mode="after")
def check_ebft_activation_offloading(self):
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
which conflicts with flex_attention's use_reentrant requirement."""
if (
data.get("rl") == "ebft"
and data.get("ebft", {}).get("mode") == "strided"
and data.get("activation_offloading") is True
and data.get("flex_attention")
self.rl == "ebft"
and (self.ebft or {}).get("mode") == "strided"
and self.activation_offloading is True
and self.attn_implementation == "flex_attention"
):
raise ValueError(
"EBFT strided mode: `activation_offloading: true` is incompatible with "
"`flex_attention: true`. Activation offloading replaces gradient checkpointing "
"with FSDP-style wrapping that conflicts with flex_attention's reentrant "
"checkpoint requirement. Remove `activation_offloading` — the strided trainer "
"uses micro-batched forward passes for memory efficiency instead."
"`attn_implementation: flex_attention`. Activation offloading replaces "
"gradient checkpointing with FSDP-style wrapping that conflicts with "
"flex_attention's reentrant checkpoint requirement. Remove "
"`activation_offloading` — the strided trainer uses micro-batched forward "
"passes for memory efficiency instead."
)
return data
return self
@model_validator(mode="before")
@classmethod