move attention-dependent validators to mode=after
This commit is contained in:
@@ -1474,32 +1474,27 @@ class AxolotlInputConfig(
|
|||||||
f"path containing '/'."
|
f"path containing '/'."
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_sageattn_wo_sample_packing(self):
|
||||||
def check_sageattn_wo_sample_packing(cls, data):
|
if (
|
||||||
is_sage = (
|
self.attn_implementation == "sage"
|
||||||
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
and not self.sample_packing
|
||||||
)
|
and not self.pad_to_sequence_len
|
||||||
if (not data.get("sample_packing", False)) and is_sage:
|
):
|
||||||
if not data.get("pad_to_sequence_len", False):
|
|
||||||
LOG.warning(
|
|
||||||
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
|
|
||||||
"This is because there has been signs that the loss explodes after a few steps."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_sageattn_fft(cls, data):
|
|
||||||
is_sage = (
|
|
||||||
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
|
||||||
)
|
|
||||||
if (not data.get("adapter", False)) and is_sage:
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"We found loss to drop to 0 with SageAttention full finetuning."
|
"We recommend turning on `pad_to_sequence_len` for SageAttention "
|
||||||
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
|
"without packing. The loss has been observed to explode otherwise."
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_sageattn_fft(self):
|
||||||
|
if self.attn_implementation == "sage" and not self.adapter:
|
||||||
|
LOG.warning(
|
||||||
|
"SageAttention full finetuning has been observed to drop loss to 0. "
|
||||||
|
"Monitor the loss, or switch to LoRA/QLoRA or another attention method."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1575,17 +1570,13 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_sample_packing_w_sdpa_bf16(self):
|
||||||
def check_sample_packing_w_sdpa_bf16(cls, data):
|
is_sm_90 = self.capabilities and self.capabilities.compute_capability == "sm_90"
|
||||||
is_sm_90: bool = (
|
|
||||||
data["capabilities"]
|
|
||||||
and data["capabilities"].get("compute_capability") == "sm_90"
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
data.get("sample_packing")
|
self.sample_packing
|
||||||
and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
|
and self.attn_implementation == "sdpa"
|
||||||
and (data.get("bfloat16") or data.get("bf16"))
|
and (self.bfloat16 or self.bf16)
|
||||||
and not is_sm_90
|
and not is_sm_90
|
||||||
):
|
):
|
||||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||||
@@ -1593,26 +1584,51 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||||
"This may work on H100s."
|
"This may work on H100s."
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
return data
|
@model_validator(mode="after")
|
||||||
|
def check_compute_capability_w_sageattn(self):
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_compute_capability_w_sageattn(cls, data):
|
|
||||||
is_sage = (
|
|
||||||
data.get("sage_attention") or data.get("attn_implementation") == "sage"
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
is_sage
|
self.attn_implementation == "sage"
|
||||||
and data.get("capabilities")
|
and self.capabilities
|
||||||
and data.get("capabilities").get("compute_capability")
|
and self.capabilities.compute_capability
|
||||||
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"SageAttention supports compute capability between sm_80 and sm_120. "
|
"SageAttention supports compute capability between sm_80 and sm_120. "
|
||||||
"Please use a different attention implementation."
|
"Please use a different attention implementation."
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_fp8_attention_preflight(self):
|
||||||
|
"""fp8 attention requires SM90+ and torch >= 2.11 (torchao >= 0.17 is pinned)."""
|
||||||
|
if self.attn_implementation != "fp8":
|
||||||
|
return self
|
||||||
|
|
||||||
|
if self.capabilities and self.capabilities.compute_capability:
|
||||||
|
cc = self.capabilities.compute_capability
|
||||||
|
# Accept sm_90 (H100/H200), sm_100 (B100/B200), sm_120 (B300-class).
|
||||||
|
if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90:
|
||||||
|
raise ValueError(
|
||||||
|
f"attn_implementation=fp8 requires compute capability sm_90 or "
|
||||||
|
f"higher (Hopper+). Detected {cc!r}."
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_version = (
|
||||||
|
self.env_capabilities.torch_version if self.env_capabilities else None
|
||||||
|
)
|
||||||
|
if torch_version is None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
if version.parse(torch_version) < version.parse("2.11.0"):
|
||||||
|
raise ValueError(
|
||||||
|
f"attn_implementation=fp8 requires PyTorch >= 2.11.0. "
|
||||||
|
f"Detected {torch_version}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1768,16 +1784,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_flex_torch_version(self):
|
||||||
def check_flex_torch_version(cls, data):
|
if self.attn_implementation == "flex_attention":
|
||||||
if (
|
torch_version = (
|
||||||
data.get("flex_attention")
|
self.env_capabilities.torch_version if self.env_capabilities else None
|
||||||
or data.get("attn_implementation") == "flex_attention"
|
)
|
||||||
):
|
|
||||||
env_capabilities = data.get("env_capabilities", {})
|
|
||||||
torch_version = env_capabilities.get("torch_version")
|
|
||||||
|
|
||||||
if torch_version is None:
|
if torch_version is None:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -1787,7 +1799,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Flex attention is not supported on torch version < 2.6.0"
|
"Flex attention is not supported on torch version < 2.6.0"
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import (
|
from axolotl.utils.schemas.enums import (
|
||||||
ATTN_IMPLS_SUPPORTING_PACKING,
|
|
||||||
ChatTemplate,
|
ChatTemplate,
|
||||||
RingAttnFunc,
|
RingAttnFunc,
|
||||||
RLType,
|
RLType,
|
||||||
@@ -187,47 +186,42 @@ class AttentionValidationMixin:
|
|||||||
# `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
|
# `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
|
||||||
# is now the single entry point for attention-input mapping and conflict detection.
|
# is now the single entry point for attention-input mapping and conflict detection.
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_sample_packing_without_attention(self):
|
||||||
def check_sample_packing_without_attention(cls, data):
|
if self.sample_packing and not self.attn_supports_packing:
|
||||||
if (
|
if self.attn_implementation:
|
||||||
data.get("sample_packing")
|
LOG.warning(
|
||||||
and not data.get("attn_implementation")
|
"`sample_packing` with `attn_implementation=%r` does not handle "
|
||||||
and not data.get("flash_attention")
|
"cross-sample decontamination. Use a varlen-capable backend "
|
||||||
and not data.get("sdp_attention")
|
"(e.g. flash_attention_2, flex_attention, xformers, sage) to "
|
||||||
and not data.get("flex_attention")
|
"isolate samples.",
|
||||||
and not data.get("xformers_attention")
|
self.attn_implementation,
|
||||||
and not data.get("sage_attention")
|
)
|
||||||
):
|
else:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
|
"`sample_packing` without an attention backend does not handle "
|
||||||
)
|
"cross-sample decontamination. Set `attn_implementation` to a "
|
||||||
return data
|
"varlen-capable backend (e.g. flash_attention_2)."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_sample_packing_with_s2attn(self):
|
||||||
def check_sample_packing_with_s2attn(cls, data):
|
if self.sample_packing and self.attn_implementation == "s2":
|
||||||
if data.get("sample_packing") and (
|
|
||||||
data.get("s2_attention") or data.get("attn_implementation") == "s2"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
"Received `sample_packing=true` and `attn_implementation=s2`; "
|
||||||
shifted-sparse attention does not currently support sample packing."
|
"shifted-sparse attention does not currently support sample packing."
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_scaling_softmax_requires_flex(self):
|
||||||
def check_scaling_softmax_requires_flex(cls, data):
|
if self.scaling_softmax and self.attn_implementation != "flex_attention":
|
||||||
if data.get("scaling_softmax") and not (
|
|
||||||
data.get("flex_attention")
|
|
||||||
or data.get("attn_implementation") == "flex_attention"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"scaling_softmax requires flex attention.\n"
|
"scaling_softmax requires flex attention. "
|
||||||
"Add 'attn_implementation: flex' to your config file.\n"
|
"Add `attn_implementation: flex_attention` to your config."
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
|
|
||||||
class TrainingValidationMixin:
|
class TrainingValidationMixin:
|
||||||
@@ -933,48 +927,45 @@ class OptimizationValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_batch_flattening_fa(self):
|
||||||
def check_batch_flattening_fa(cls, data):
|
if not self.batch_flattening:
|
||||||
if data.get("batch_flattening"):
|
return self
|
||||||
batch_flattening_auto = data.get("batch_flattening") == "auto"
|
|
||||||
has_varlen_attn = (
|
batch_flattening_auto = self.batch_flattening == "auto"
|
||||||
data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING
|
has_varlen_attn = self.attn_supports_packing
|
||||||
if data.get("attn_implementation")
|
|
||||||
else data.get("flash_attention")
|
if not has_varlen_attn and not batch_flattening_auto:
|
||||||
|
raise ValueError(
|
||||||
|
"batch_flattening requires a varlen-capable attention backend "
|
||||||
|
"(e.g., attn_implementation: flash_attention_2)."
|
||||||
)
|
)
|
||||||
if not has_varlen_attn and not batch_flattening_auto:
|
if self.sample_packing and not batch_flattening_auto:
|
||||||
raise ValueError(
|
raise ValueError("batch_flattening not compatible with sample_packing")
|
||||||
"batch_flattening requires a varlen-capable attention backend "
|
if self.micro_batch_size == 1 and not batch_flattening_auto:
|
||||||
"(e.g., attn_implementation: flash)"
|
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
|
||||||
)
|
|
||||||
if data.get("sample_packing") and not batch_flattening_auto:
|
|
||||||
raise ValueError("batch_flattening not compatible with sample_packing")
|
|
||||||
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
|
|
||||||
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
|
|
||||||
|
|
||||||
# Liger loss takes a separate code path (compute_liger_loss) that
|
# Liger loss takes a separate code path (compute_liger_loss) that
|
||||||
# bypasses the flattened training forward pass. Batch flattening
|
# bypasses the flattened training forward pass. Batch flattening
|
||||||
# still applies to the scoring/deferred logprobs path.
|
# still applies to the scoring/deferred logprobs path.
|
||||||
trl_cfg = data.get("trl") or {}
|
if self.trl and getattr(self.trl, "use_liger_loss", False):
|
||||||
if isinstance(trl_cfg, dict) and trl_cfg.get("use_liger_loss"):
|
LOG.warning(
|
||||||
LOG.warning(
|
"batch_flattening with use_liger_loss: flattening will only "
|
||||||
"batch_flattening with use_liger_loss: flattening will only "
|
"apply to the scoring path (deferred logprobs). The training "
|
||||||
"apply to the scoring path (deferred logprobs). The training "
|
"forward pass uses Liger's fused lm_head+loss kernel instead."
|
||||||
"forward pass uses Liger's fused lm_head+loss kernel instead."
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
batch_flattening_auto
|
batch_flattening_auto
|
||||||
and has_varlen_attn
|
and has_varlen_attn
|
||||||
and not data.get("sample_packing")
|
and not self.sample_packing
|
||||||
and data.get("micro_batch_size") > 1
|
and self.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
data["batch_flattening"] = True
|
self.batch_flattening = True
|
||||||
elif batch_flattening_auto:
|
elif batch_flattening_auto:
|
||||||
data["batch_flattening"] = False
|
self.batch_flattening = False
|
||||||
|
|
||||||
return data
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1211,12 +1202,18 @@ class SystemValidationMixin:
|
|||||||
def check_npu_config(cls, data):
|
def check_npu_config(cls, data):
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
# check attention config
|
# check attention config
|
||||||
unsupported_npu_impls = {"flash", "sdpa", "s2"}
|
unsupported_npu_impls = {
|
||||||
|
"flash_attention_2",
|
||||||
|
"flash_attention_3",
|
||||||
|
"sdpa",
|
||||||
|
"s2",
|
||||||
|
}
|
||||||
attn_impl = data.get("attn_implementation")
|
attn_impl = data.get("attn_implementation")
|
||||||
if attn_impl and attn_impl in unsupported_npu_impls:
|
if attn_impl and attn_impl in unsupported_npu_impls:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
|
f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
|
||||||
)
|
)
|
||||||
|
# Legacy flags still present at this point (normalizer strips them later).
|
||||||
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
||||||
for attn in attn_list:
|
for attn in attn_list:
|
||||||
if data.get(attn):
|
if data.get(attn):
|
||||||
@@ -1658,50 +1655,46 @@ class EBFTValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_ebft_gradient_checkpointing_reentrant(self):
|
||||||
def check_ebft_gradient_checkpointing_reentrant(cls, data):
|
|
||||||
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
|
"""flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
|
||||||
if (
|
if (
|
||||||
data.get("rl") == "ebft"
|
self.rl == "ebft"
|
||||||
and data.get("ebft", {}).get("mode") == "strided"
|
and (self.ebft or {}).get("mode") == "strided"
|
||||||
and (
|
and self.attn_implementation == "flex_attention"
|
||||||
data.get("flex_attention")
|
and self.gradient_checkpointing
|
||||||
or data.get("attn_implementation") == "flex_attention"
|
|
||||||
)
|
|
||||||
and data.get("gradient_checkpointing")
|
|
||||||
):
|
):
|
||||||
gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
|
gc_kwargs = self.gradient_checkpointing_kwargs or {}
|
||||||
if not gc_kwargs.get("use_reentrant"):
|
if not gc_kwargs.get("use_reentrant"):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
|
"EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
|
||||||
"gradient_checkpointing_kwargs (required for flex_attention compatibility). "
|
"gradient_checkpointing_kwargs (required for flex_attention compatibility). "
|
||||||
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
|
"Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
|
||||||
)
|
)
|
||||||
if data.get("gradient_checkpointing_kwargs") is None:
|
if self.gradient_checkpointing_kwargs is None:
|
||||||
data["gradient_checkpointing_kwargs"] = {}
|
self.gradient_checkpointing_kwargs = {}
|
||||||
data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
|
self.gradient_checkpointing_kwargs["use_reentrant"] = True
|
||||||
return data
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_ebft_activation_offloading(self):
|
||||||
def check_ebft_activation_offloading(cls, data):
|
|
||||||
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
|
"""activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
|
||||||
which conflicts with flex_attention's use_reentrant requirement."""
|
which conflicts with flex_attention's use_reentrant requirement."""
|
||||||
if (
|
if (
|
||||||
data.get("rl") == "ebft"
|
self.rl == "ebft"
|
||||||
and data.get("ebft", {}).get("mode") == "strided"
|
and (self.ebft or {}).get("mode") == "strided"
|
||||||
and data.get("activation_offloading") is True
|
and self.activation_offloading is True
|
||||||
and data.get("flex_attention")
|
and self.attn_implementation == "flex_attention"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"EBFT strided mode: `activation_offloading: true` is incompatible with "
|
"EBFT strided mode: `activation_offloading: true` is incompatible with "
|
||||||
"`flex_attention: true`. Activation offloading replaces gradient checkpointing "
|
"`attn_implementation: flex_attention`. Activation offloading replaces "
|
||||||
"with FSDP-style wrapping that conflicts with flex_attention's reentrant "
|
"gradient checkpointing with FSDP-style wrapping that conflicts with "
|
||||||
"checkpoint requirement. Remove `activation_offloading` — the strided trainer "
|
"flex_attention's reentrant checkpoint requirement. Remove "
|
||||||
"uses micro-batched forward passes for memory efficiency instead."
|
"`activation_offloading` — the strided trainer uses micro-batched forward "
|
||||||
|
"passes for memory efficiency instead."
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user