move attention-dependent validators to mode=after
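The diff below moves these checks from pydantic `mode="before"` validators, which receive the raw input dict and must consult legacy attention flags by hand, to `mode="after"` validators, which run on the constructed model once `attn_implementation` has been normalized. A minimal sketch of the two validator styles, using hypothetical field names rather than the real Axolotl config classes:

import logging

from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)


class ExampleConfig(BaseModel):
    # Hypothetical fields for illustration only.
    attn_implementation: str | None = None
    sample_packing: bool = False
    pad_to_sequence_len: bool = False

    @model_validator(mode="before")
    @classmethod
    def before_style(cls, data):
        # Runs on the raw input before parsing, so legacy aliases and the
        # canonical field both have to be checked by hand.
        if isinstance(data, dict) and data.get("attn_implementation") == "sage":
            ...
        return data

    @model_validator(mode="after")
    def after_style(self):
        # Runs on the constructed, typed model, so normalized fields can be
        # read directly as attributes.
        if (
            self.attn_implementation == "sage"
            and not self.sample_packing
            and not self.pad_to_sequence_len
        ):
            LOG.warning("pad_to_sequence_len is recommended with SageAttention.")
        return self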
@@ -1474,32 +1474,27 @@ class AxolotlInputConfig(
                     f"path containing '/'."
                 )

-    @model_validator(mode="before")
-    @classmethod
-    def check_sageattn_wo_sample_packing(cls, data):
-        is_sage = (
-            data.get("sage_attention") or data.get("attn_implementation") == "sage"
-        )
-        if (not data.get("sample_packing", False)) and is_sage:
-            if not data.get("pad_to_sequence_len", False):
-                LOG.warning(
-                    "We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
-                    "This is because there has been signs that the loss explodes after a few steps."
-                )
-        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_sageattn_fft(cls, data):
-        is_sage = (
-            data.get("sage_attention") or data.get("attn_implementation") == "sage"
-        )
-        if (not data.get("adapter", False)) and is_sage:
+    @model_validator(mode="after")
+    def check_sageattn_wo_sample_packing(self):
+        if (
+            self.attn_implementation == "sage"
+            and not self.sample_packing
+            and not self.pad_to_sequence_len
+        ):
             LOG.warning(
-                "We found loss to drop to 0 with SageAttention full finetuning."
-                "Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
+                "We recommend turning on `pad_to_sequence_len` for SageAttention "
+                "without packing. The loss has been observed to explode otherwise."
             )
-        return data
+        return self
+
+    @model_validator(mode="after")
+    def check_sageattn_fft(self):
+        if self.attn_implementation == "sage" and not self.adapter:
+            LOG.warning(
+                "SageAttention full finetuning has been observed to drop loss to 0. "
+                "Monitor the loss, or switch to LoRA/QLoRA or another attention method."
+            )
+        return self

     @model_validator(mode="before")
     @classmethod
@@ -1575,17 +1570,13 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             )
         return self

-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_w_sdpa_bf16(cls, data):
-        is_sm_90: bool = (
-            data["capabilities"]
-            and data["capabilities"].get("compute_capability") == "sm_90"
-        )
+    @model_validator(mode="after")
+    def check_sample_packing_w_sdpa_bf16(self):
+        is_sm_90 = self.capabilities and self.capabilities.compute_capability == "sm_90"
         if (
-            data.get("sample_packing")
-            and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
-            and (data.get("bfloat16") or data.get("bf16"))
+            self.sample_packing
+            and self.attn_implementation == "sdpa"
+            and (self.bfloat16 or self.bf16)
             and not is_sm_90
         ):
             # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
@@ -1593,26 +1584,51 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
                 "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
                 "This may work on H100s."
             )
+        return self
-        return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_compute_capability_w_sageattn(cls, data):
-        is_sage = (
-            data.get("sage_attention") or data.get("attn_implementation") == "sage"
-        )
+    @model_validator(mode="after")
+    def check_compute_capability_w_sageattn(self):
         if (
-            is_sage
-            and data.get("capabilities")
-            and data.get("capabilities").get("compute_capability")
+            self.attn_implementation == "sage"
+            and self.capabilities
+            and self.capabilities.compute_capability
             not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
         ):
             raise ValueError(
                 "SageAttention supports compute capability between sm_80 and sm_120. "
                 "Please use a different attention implementation."
             )
-        return data
+        return self
+
+    @model_validator(mode="after")
+    def check_fp8_attention_preflight(self):
+        """fp8 attention requires SM90+ and torch >= 2.11 (torchao >= 0.17 is pinned)."""
+        if self.attn_implementation != "fp8":
+            return self
+
+        if self.capabilities and self.capabilities.compute_capability:
+            cc = self.capabilities.compute_capability
+            # Accept sm_90 (H100/H200), sm_100 (B100/B200), sm_120 (B300-class).
+            if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90:
+                raise ValueError(
+                    f"attn_implementation=fp8 requires compute capability sm_90 or "
+                    f"higher (Hopper+). Detected {cc!r}."
+                )
+
+        torch_version = (
+            self.env_capabilities.torch_version if self.env_capabilities else None
+        )
+        if torch_version is None:
+            import torch
+
+            torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+        if version.parse(torch_version) < version.parse("2.11.0"):
+            raise ValueError(
+                f"attn_implementation=fp8 requires PyTorch >= 2.11.0. "
+                f"Detected {torch_version}."
+            )
+
+        return self

     @model_validator(mode="before")
     @classmethod
@@ -1768,16 +1784,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             )
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_flex_torch_version(cls, data):
-        if (
-            data.get("flex_attention")
-            or data.get("attn_implementation") == "flex_attention"
-        ):
-            env_capabilities = data.get("env_capabilities", {})
-            torch_version = env_capabilities.get("torch_version")
+    @model_validator(mode="after")
+    def check_flex_torch_version(self):
+        if self.attn_implementation == "flex_attention":
+            torch_version = (
+                self.env_capabilities.torch_version if self.env_capabilities else None
+            )
             if torch_version is None:
                 import torch
@@ -1787,7 +1799,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
                 raise ValueError(
                     "Flex attention is not supported on torch version < 2.6.0"
                 )
-        return data
+        return self

     @model_validator(mode="before")
     @classmethod
@@ -13,7 +13,6 @@ from transformers.utils.import_utils import is_torch_npu_available

 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import (
-    ATTN_IMPLS_SUPPORTING_PACKING,
     ChatTemplate,
     RingAttnFunc,
     RLType,
@@ -187,47 +186,42 @@ class AttentionValidationMixin:
     # `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
     # is now the single entry point for attention-input mapping and conflict detection.

-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_without_attention(cls, data):
-        if (
-            data.get("sample_packing")
-            and not data.get("attn_implementation")
-            and not data.get("flash_attention")
-            and not data.get("sdp_attention")
-            and not data.get("flex_attention")
-            and not data.get("xformers_attention")
-            and not data.get("sage_attention")
-        ):
-            LOG.warning(
-                "sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
-            )
-        return data
+    @model_validator(mode="after")
+    def check_sample_packing_without_attention(self):
+        if self.sample_packing and not self.attn_supports_packing:
+            if self.attn_implementation:
+                LOG.warning(
+                    "`sample_packing` with `attn_implementation=%r` does not handle "
+                    "cross-sample decontamination. Use a varlen-capable backend "
+                    "(e.g. flash_attention_2, flex_attention, xformers, sage) to "
+                    "isolate samples.",
+                    self.attn_implementation,
+                )
+            else:
+                LOG.warning(
+                    "`sample_packing` without an attention backend does not handle "
+                    "cross-sample decontamination. Set `attn_implementation` to a "
+                    "varlen-capable backend (e.g. flash_attention_2)."
+                )
+        return self

-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_with_s2attn(cls, data):
-        if data.get("sample_packing") and (
-            data.get("s2_attention") or data.get("attn_implementation") == "s2"
-        ):
+    @model_validator(mode="after")
+    def check_sample_packing_with_s2attn(self):
+        if self.sample_packing and self.attn_implementation == "s2":
             raise ValueError(
-                "Received `sample_packing=true` and `s2_attention=true`; however, \
-                    shifted-sparse attention does not currently support sample packing."
+                "Received `sample_packing=true` and `attn_implementation=s2`; "
+                "shifted-sparse attention does not currently support sample packing."
             )
-        return data
+        return self

-    @model_validator(mode="before")
-    @classmethod
-    def check_scaling_softmax_requires_flex(cls, data):
-        if data.get("scaling_softmax") and not (
-            data.get("flex_attention")
-            or data.get("attn_implementation") == "flex_attention"
-        ):
+    @model_validator(mode="after")
+    def check_scaling_softmax_requires_flex(self):
+        if self.scaling_softmax and self.attn_implementation != "flex_attention":
             raise ValueError(
-                "scaling_softmax requires flex attention.\n"
-                "Add 'attn_implementation: flex' to your config file.\n"
+                "scaling_softmax requires flex attention. "
+                "Add `attn_implementation: flex_attention` to your config."
             )
-        return data
+        return self


 class TrainingValidationMixin:
@@ -933,48 +927,45 @@ class OptimizationValidationMixin:
             )
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_batch_flattening_fa(cls, data):
-        if data.get("batch_flattening"):
-            batch_flattening_auto = data.get("batch_flattening") == "auto"
-            has_varlen_attn = (
-                data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING
-                if data.get("attn_implementation")
-                else data.get("flash_attention")
-            )
+    @model_validator(mode="after")
+    def check_batch_flattening_fa(self):
+        if not self.batch_flattening:
+            return self
+
+        batch_flattening_auto = self.batch_flattening == "auto"
+        has_varlen_attn = self.attn_supports_packing
+
+        if not has_varlen_attn and not batch_flattening_auto:
+            raise ValueError(
+                "batch_flattening requires a varlen-capable attention backend "
+                "(e.g., attn_implementation: flash_attention_2)."
+            )
-            if not has_varlen_attn and not batch_flattening_auto:
-                raise ValueError(
-                    "batch_flattening requires a varlen-capable attention backend "
-                    "(e.g., attn_implementation: flash)"
-                )
-            if data.get("sample_packing") and not batch_flattening_auto:
-                raise ValueError("batch_flattening not compatible with sample_packing")
-            if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
-                LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
+        if self.sample_packing and not batch_flattening_auto:
+            raise ValueError("batch_flattening not compatible with sample_packing")
+        if self.micro_batch_size == 1 and not batch_flattening_auto:
+            LOG.warning("batch_flattening has no effect with micro_batch_size == 1")

-            # Liger loss takes a separate code path (compute_liger_loss) that
-            # bypasses the flattened training forward pass. Batch flattening
-            # still applies to the scoring/deferred logprobs path.
-            trl_cfg = data.get("trl") or {}
-            if isinstance(trl_cfg, dict) and trl_cfg.get("use_liger_loss"):
-                LOG.warning(
-                    "batch_flattening with use_liger_loss: flattening will only "
-                    "apply to the scoring path (deferred logprobs). The training "
-                    "forward pass uses Liger's fused lm_head+loss kernel instead."
-                )
+        # Liger loss takes a separate code path (compute_liger_loss) that
+        # bypasses the flattened training forward pass. Batch flattening
+        # still applies to the scoring/deferred logprobs path.
+        if self.trl and getattr(self.trl, "use_liger_loss", False):
+            LOG.warning(
+                "batch_flattening with use_liger_loss: flattening will only "
+                "apply to the scoring path (deferred logprobs). The training "
+                "forward pass uses Liger's fused lm_head+loss kernel instead."
+            )

-            if (
-                batch_flattening_auto
-                and has_varlen_attn
-                and not data.get("sample_packing")
-                and data.get("micro_batch_size") > 1
-            ):
-                data["batch_flattening"] = True
-            elif batch_flattening_auto:
-                data["batch_flattening"] = False
+        if (
+            batch_flattening_auto
+            and has_varlen_attn
+            and not self.sample_packing
+            and self.micro_batch_size > 1
+        ):
+            self.batch_flattening = True
+        elif batch_flattening_auto:
+            self.batch_flattening = False

-        return data
+        return self

     @model_validator(mode="before")
     @classmethod
@@ -1211,12 +1202,18 @@ class SystemValidationMixin:
     def check_npu_config(cls, data):
         if is_torch_npu_available():
             # check attention config
-            unsupported_npu_impls = {"flash", "sdpa", "s2"}
+            unsupported_npu_impls = {
+                "flash_attention_2",
+                "flash_attention_3",
+                "sdpa",
+                "s2",
+            }
             attn_impl = data.get("attn_implementation")
             if attn_impl and attn_impl in unsupported_npu_impls:
                 raise NotImplementedError(
                     f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
                 )
+            # Legacy flags still present at this point (normalizer strips them later).
             attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
             for attn in attn_list:
                 if data.get(attn):
@@ -1658,50 +1655,46 @@ class EBFTValidationMixin:
             )
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_ebft_gradient_checkpointing_reentrant(cls, data):
+    @model_validator(mode="after")
+    def check_ebft_gradient_checkpointing_reentrant(self):
         """flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
         if (
-            data.get("rl") == "ebft"
-            and data.get("ebft", {}).get("mode") == "strided"
-            and (
-                data.get("flex_attention")
-                or data.get("attn_implementation") == "flex_attention"
-            )
-            and data.get("gradient_checkpointing")
+            self.rl == "ebft"
+            and (self.ebft or {}).get("mode") == "strided"
+            and self.attn_implementation == "flex_attention"
+            and self.gradient_checkpointing
         ):
-            gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
+            gc_kwargs = self.gradient_checkpointing_kwargs or {}
             if not gc_kwargs.get("use_reentrant"):
                 LOG.warning(
                     "EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
                     "gradient_checkpointing_kwargs (required for flex_attention compatibility). "
                     "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
                 )
-                if data.get("gradient_checkpointing_kwargs") is None:
-                    data["gradient_checkpointing_kwargs"] = {}
-                data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
-        return data
+                if self.gradient_checkpointing_kwargs is None:
+                    self.gradient_checkpointing_kwargs = {}
+                self.gradient_checkpointing_kwargs["use_reentrant"] = True
+        return self

-    @model_validator(mode="before")
-    @classmethod
-    def check_ebft_activation_offloading(cls, data):
+    @model_validator(mode="after")
+    def check_ebft_activation_offloading(self):
         """activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
         which conflicts with flex_attention's use_reentrant requirement."""
         if (
-            data.get("rl") == "ebft"
-            and data.get("ebft", {}).get("mode") == "strided"
-            and data.get("activation_offloading") is True
-            and data.get("flex_attention")
+            self.rl == "ebft"
+            and (self.ebft or {}).get("mode") == "strided"
+            and self.activation_offloading is True
+            and self.attn_implementation == "flex_attention"
         ):
             raise ValueError(
                 "EBFT strided mode: `activation_offloading: true` is incompatible with "
-                "`flex_attention: true`. Activation offloading replaces gradient checkpointing "
-                "with FSDP-style wrapping that conflicts with flex_attention's reentrant "
-                "checkpoint requirement. Remove `activation_offloading` — the strided trainer "
-                "uses micro-batched forward passes for memory efficiency instead."
+                "`attn_implementation: flex_attention`. Activation offloading replaces "
+                "gradient checkpointing with FSDP-style wrapping that conflicts with "
+                "flex_attention's reentrant checkpoint requirement. Remove "
+                "`activation_offloading` — the strided trainer uses micro-batched forward "
+                "passes for memory efficiency instead."
             )
-        return data
+        return self

     @model_validator(mode="before")
     @classmethod