diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index fb9cf3bde..366576f0c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1474,32 +1474,27 @@ class AxolotlInputConfig(
                 f"path containing '/'."
             )
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_sageattn_wo_sample_packing(cls, data):
-        is_sage = (
-            data.get("sage_attention") or data.get("attn_implementation") == "sage"
-        )
-        if (not data.get("sample_packing", False)) and is_sage:
-            if not data.get("pad_to_sequence_len", False):
-                LOG.warning(
-                    "We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
-                    "This is because there has been signs that the loss explodes after a few steps."
-                )
-        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_sageattn_fft(cls, data):
-        is_sage = (
-            data.get("sage_attention") or data.get("attn_implementation") == "sage"
-        )
-        if (not data.get("adapter", False)) and is_sage:
+    @model_validator(mode="after")
+    def check_sageattn_wo_sample_packing(self):
+        if (
+            self.attn_implementation == "sage"
+            and not self.sample_packing
+            and not self.pad_to_sequence_len
+        ):
             LOG.warning(
-                "We found loss to drop to 0 with SageAttention full finetuning."
-                "Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
+                "We recommend turning on `pad_to_sequence_len` for SageAttention "
+                "without packing. The loss has been observed to explode otherwise."
             )
-        return data
+        return self
+
+    @model_validator(mode="after")
+    def check_sageattn_fft(self):
+        if self.attn_implementation == "sage" and not self.adapter:
+            LOG.warning(
+                "SageAttention full finetuning has been observed to drop loss to 0. "
+                "Monitor the loss, or switch to LoRA/QLoRA or another attention method."
+            )
+        return self
 
     @model_validator(mode="before")
     @classmethod
@@ -1575,17 +1570,13 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             )
         return self
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_w_sdpa_bf16(cls, data):
-        is_sm_90: bool = (
-            data["capabilities"]
-            and data["capabilities"].get("compute_capability") == "sm_90"
-        )
+    @model_validator(mode="after")
+    def check_sample_packing_w_sdpa_bf16(self):
+        is_sm_90 = self.capabilities and self.capabilities.compute_capability == "sm_90"
         if (
-            data.get("sample_packing")
-            and (data.get("sdp_attention") or data.get("attn_implementation") == "sdpa")
-            and (data.get("bfloat16") or data.get("bf16"))
+            self.sample_packing
+            and self.attn_implementation == "sdpa"
+            and (self.bfloat16 or self.bf16)
             and not is_sm_90
         ):
             # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
@@ -1593,26 +1584,51 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             LOG.warning(
                 "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
                 "This may work on H100s."
             )
+        return self
 
-        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_compute_capability_w_sageattn(cls, data):
-        is_sage = (
-            data.get("sage_attention") or data.get("attn_implementation") == "sage"
-        )
+    @model_validator(mode="after")
+    def check_compute_capability_w_sageattn(self):
         if (
-            is_sage
-            and data.get("capabilities")
-            and data.get("capabilities").get("compute_capability")
+            self.attn_implementation == "sage"
+            and self.capabilities
+            and self.capabilities.compute_capability
             not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
         ):
             raise ValueError(
                 "SageAttention supports compute capability between sm_80 and sm_120. "
                 "Please use a different attention implementation."
             )
-        return data
+        return self
+
+    @model_validator(mode="after")
+    def check_fp8_attention_preflight(self):
+        """fp8 attention requires SM90+ and torch >= 2.11 (torchao >= 0.17 is pinned)."""
+        if self.attn_implementation != "fp8":
+            return self
+
+        if self.capabilities and self.capabilities.compute_capability:
+            cc = self.capabilities.compute_capability
+            # Accept sm_90 (H100/H200), sm_100 (B100/B200), sm_120 (Blackwell RTX).
+            if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90:
+                raise ValueError(
+                    f"attn_implementation=fp8 requires compute capability sm_90 or "
+                    f"higher (Hopper+). Detected {cc!r}."
+                )
+
+        torch_version = (
+            self.env_capabilities.torch_version if self.env_capabilities else None
+        )
+        if torch_version is None:
+            import torch
+
+            torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+        if version.parse(torch_version) < version.parse("2.11.0"):
+            raise ValueError(
+                f"attn_implementation=fp8 requires PyTorch >= 2.11.0. "
+                f"Detected {torch_version}."
+            )
+
+        return self
 
     @model_validator(mode="before")
     @classmethod
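
The fp8 preflight above gates on two axes: the parsed SM number and the torch version. Below is a standalone sketch of the same gating, runnable outside pydantic; the function name and the return-an-error-string convention are illustrative, not part of the patch.

# Standalone mirror of check_fp8_attention_preflight (sketch, not patch code).
from packaging import version


def fp8_preflight(attn_implementation, compute_capability, torch_version):
    """Return an error string where the fp8 gate would raise, else None."""
    if attn_implementation != "fp8":
        return None
    if compute_capability:
        cc = compute_capability
        # "sm_90" -> 90, "sm_100" -> 100; anything below 90 (pre-Hopper) is rejected.
        if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90:
            return f"fp8 requires sm_90 or higher, detected {cc!r}"
    if version.parse(torch_version) < version.parse("2.11.0"):
        return f"fp8 requires torch >= 2.11.0, detected {torch_version}"
    return None


assert fp8_preflight("sdpa", "sm_80", "2.6.0") is None      # non-fp8 configs pass through
assert fp8_preflight("fp8", "sm_89", "2.11.0") is not None  # Ada (sm_89) is below sm_90
assert fp8_preflight("fp8", "sm_90", "2.10.0") is not None  # torch too old
assert fp8_preflight("fp8", "sm_100", "2.11.0") is None     # Hopper+/new torch: ok

Note that an "sm_90a"-style suffixed capability string would make the int() parse raise; both the sketch and the patch assume bare sm_NN strings from the capabilities probe.
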
@@ -1768,16 +1784,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_flex_torch_version(cls, data):
-        if (
-            data.get("flex_attention")
-            or data.get("attn_implementation") == "flex_attention"
-        ):
-            env_capabilities = data.get("env_capabilities", {})
-            torch_version = env_capabilities.get("torch_version")
-
+    @model_validator(mode="after")
+    def check_flex_torch_version(self):
+        if self.attn_implementation == "flex_attention":
+            torch_version = (
+                self.env_capabilities.torch_version if self.env_capabilities else None
+            )
             if torch_version is None:
                 import torch
 
@@ -1787,7 +1799,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
                 raise ValueError(
                     "Flex attention is not supported on torch version < 2.6.0"
                 )
-        return data
+        return self
 
     @model_validator(mode="before")
    @classmethod
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 6f57da971..3b7b93986 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -13,7 +13,6 @@ from transformers.utils.import_utils import is_torch_npu_available
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import (
-    ATTN_IMPLS_SUPPORTING_PACKING,
     ChatTemplate,
     RingAttnFunc,
     RLType,
 )
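
The flex and fp8 checks share the same torch-version fallback: prefer `env_capabilities.torch_version`, else read `torch.__version__` and drop the local build tag before comparing. A minimal runnable sketch of that normalization, with made-up helper names:

# Sketch of the shared torch-version fallback (helper names are illustrative).
from packaging import version


def effective_torch_version(env_torch_version=None):
    """Prefer the env-provided version, else normalize torch.__version__."""
    if env_torch_version is not None:
        return env_torch_version
    import torch  # deferred, so configs that carry env_capabilities skip the import

    # "2.6.0+cu124" -> "2.6.0": drop the local build tag before comparing.
    return str(torch.__version__).split("+", maxsplit=1)[0]


def flex_supported(torch_version):
    return version.parse(torch_version) >= version.parse("2.6.0")


assert flex_supported(effective_torch_version("2.6.0"))
assert not flex_supported(effective_torch_version("2.5.1"))
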
@@ -187,47 +186,42 @@ class AttentionValidationMixin:
     # `check_attention_fields` was removed — `AxolotlInputConfig.normalize_attn_implementation`
     # is now the single entry point for attention-input mapping and conflict detection.
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_without_attention(cls, data):
-        if (
-            data.get("sample_packing")
-            and not data.get("attn_implementation")
-            and not data.get("flash_attention")
-            and not data.get("sdp_attention")
-            and not data.get("flex_attention")
-            and not data.get("xformers_attention")
-            and not data.get("sage_attention")
-        ):
-            LOG.warning(
-                "sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
-            )
-        return data
+    @model_validator(mode="after")
+    def check_sample_packing_without_attention(self):
+        if self.sample_packing and not self.attn_supports_packing:
+            if self.attn_implementation:
+                LOG.warning(
+                    "`sample_packing` with `attn_implementation=%r` does not handle "
+                    "cross-sample decontamination. Use a varlen-capable backend "
+                    "(e.g. flash_attention_2, flex_attention, xformers, sage) to "
+                    "isolate samples.",
+                    self.attn_implementation,
+                )
+            else:
+                LOG.warning(
+                    "`sample_packing` without an attention backend does not handle "
+                    "cross-sample decontamination. Set `attn_implementation` to a "
+                    "varlen-capable backend (e.g. flash_attention_2)."
+                )
+        return self
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_with_s2attn(cls, data):
-        if data.get("sample_packing") and (
-            data.get("s2_attention") or data.get("attn_implementation") == "s2"
-        ):
+    @model_validator(mode="after")
+    def check_sample_packing_with_s2attn(self):
+        if self.sample_packing and self.attn_implementation == "s2":
             raise ValueError(
-                "Received `sample_packing=true` and `s2_attention=true`; however, \
-                shifted-sparse attention does not currently support sample packing."
+                "Received `sample_packing=true` and `attn_implementation=s2`; "
+                "shifted-sparse attention does not currently support sample packing."
             )
-        return data
+        return self
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_scaling_softmax_requires_flex(cls, data):
-        if data.get("scaling_softmax") and not (
-            data.get("flex_attention")
-            or data.get("attn_implementation") == "flex_attention"
-        ):
+    @model_validator(mode="after")
+    def check_scaling_softmax_requires_flex(self):
+        if self.scaling_softmax and self.attn_implementation != "flex_attention":
             raise ValueError(
-                "scaling_softmax requires flex attention.\n"
-                "Add 'attn_implementation: flex' to your config file.\n"
+                "scaling_softmax requires flex attention. "
+                "Add `attn_implementation: flex_attention` to your config."
             )
-        return data
+        return self
 
 
 class TrainingValidationMixin:
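
Taken together, the three validators above reduce to a small decision table over `sample_packing`, `attn_implementation`, and `scaling_softmax`. A standalone mirror for reference; `VARLEN_IMPLS` is an assumed stand-in for the `attn_supports_packing` property, whose real set ships with axolotl's enums and may differ:

# Standalone mirror of the three packing checks (sketch, not patch code).
# VARLEN_IMPLS approximates `attn_supports_packing`; the authoritative set
# lives in axolotl and may differ.
VARLEN_IMPLS = {"flash_attention_2", "flash_attention_3", "flex_attention", "xformers", "sage"}


def check_packing(sample_packing, attn_implementation, scaling_softmax=False):
    if sample_packing and attn_implementation == "s2":
        raise ValueError("shifted-sparse attention does not support sample packing")
    if scaling_softmax and attn_implementation != "flex_attention":
        raise ValueError("scaling_softmax requires attn_implementation: flex_attention")
    if sample_packing and attn_implementation not in VARLEN_IMPLS:
        return "warn: packed samples attend across boundaries"
    return "ok"


assert check_packing(True, "flash_attention_2") == "ok"
assert check_packing(True, "sdpa").startswith("warn")  # packs, but no varlen masking
assert check_packing(False, "flex_attention", scaling_softmax=True) == "ok"
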
@@ -933,48 +927,45 @@ class OptimizationValidationMixin:
             )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_batch_flattening_fa(cls, data):
-        if data.get("batch_flattening"):
-            batch_flattening_auto = data.get("batch_flattening") == "auto"
-            has_varlen_attn = (
-                data.get("attn_implementation") in ATTN_IMPLS_SUPPORTING_PACKING
-                if data.get("attn_implementation")
-                else data.get("flash_attention")
-            )
+    @model_validator(mode="after")
+    def check_batch_flattening_fa(self):
+        if not self.batch_flattening:
+            return self
+
+        batch_flattening_auto = self.batch_flattening == "auto"
+        has_varlen_attn = self.attn_supports_packing
+
+        if not has_varlen_attn and not batch_flattening_auto:
+            raise ValueError(
+                "batch_flattening requires a varlen-capable attention backend "
+                "(e.g., attn_implementation: flash_attention_2)."
+            )
-            if not has_varlen_attn and not batch_flattening_auto:
-                raise ValueError(
-                    "batch_flattening requires a varlen-capable attention backend "
-                    "(e.g., attn_implementation: flash)"
-                )
-            if data.get("sample_packing") and not batch_flattening_auto:
-                raise ValueError("batch_flattening not compatible with sample_packing")
-            if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
-                LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
+        if self.sample_packing and not batch_flattening_auto:
+            raise ValueError("batch_flattening not compatible with sample_packing")
+        if self.micro_batch_size == 1 and not batch_flattening_auto:
+            LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
 
-            # Liger loss takes a separate code path (compute_liger_loss) that
-            # bypasses the flattened training forward pass. Batch flattening
-            # still applies to the scoring/deferred logprobs path.
-            trl_cfg = data.get("trl") or {}
-            if isinstance(trl_cfg, dict) and trl_cfg.get("use_liger_loss"):
-                LOG.warning(
-                    "batch_flattening with use_liger_loss: flattening will only "
-                    "apply to the scoring path (deferred logprobs). The training "
-                    "forward pass uses Liger's fused lm_head+loss kernel instead."
-                )
+        # Liger loss takes a separate code path (compute_liger_loss) that
+        # bypasses the flattened training forward pass. Batch flattening
+        # still applies to the scoring/deferred logprobs path.
+        if self.trl and getattr(self.trl, "use_liger_loss", False):
+            LOG.warning(
+                "batch_flattening with use_liger_loss: flattening will only "
+                "apply to the scoring path (deferred logprobs). The training "
+                "forward pass uses Liger's fused lm_head+loss kernel instead."
+            )
 
-            if (
-                batch_flattening_auto
-                and has_varlen_attn
-                and not data.get("sample_packing")
-                and data.get("micro_batch_size") > 1
-            ):
-                data["batch_flattening"] = True
-            elif batch_flattening_auto:
-                data["batch_flattening"] = False
+        if (
+            batch_flattening_auto
+            and has_varlen_attn
+            and not self.sample_packing
+            and self.micro_batch_size > 1
+        ):
+            self.batch_flattening = True
+        elif batch_flattening_auto:
+            self.batch_flattening = False
 
-            return data
+        return self
 
     @model_validator(mode="before")
     @classmethod
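
The `auto` branch above resolves to a single predicate: flatten only when a varlen backend is present, packing is off, and there is more than one sequence per micro batch. A runnable mirror of just that resolution (sketch, not the patch's code; `has_varlen_attn` stands in for `attn_supports_packing`):

# Mirror of the batch_flattening "auto" resolution above (sketch, not patch code).
def resolve_batch_flattening(batch_flattening, has_varlen_attn, sample_packing, micro_batch_size):
    if batch_flattening != "auto":
        return batch_flattening  # explicit true/false goes through validation instead
    # auto -> on only when flattening can help and nothing conflicts
    return has_varlen_attn and not sample_packing and micro_batch_size > 1


assert resolve_batch_flattening("auto", True, False, 4) is True
assert resolve_batch_flattening("auto", True, True, 4) is False   # packing wins
assert resolve_batch_flattening("auto", True, False, 1) is False  # nothing to flatten
assert resolve_batch_flattening("auto", False, False, 4) is False # no varlen backend
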
@@ -1211,12 +1202,18 @@ class SystemValidationMixin:
     def check_npu_config(cls, data):
         if is_torch_npu_available():
             # check attention config
-            unsupported_npu_impls = {"flash", "sdpa", "s2"}
+            unsupported_npu_impls = {
+                "flash_attention_2",
+                "flash_attention_3",
+                "sdpa",
+                "s2",
+            }
             attn_impl = data.get("attn_implementation")
             if attn_impl and attn_impl in unsupported_npu_impls:
                 raise NotImplementedError(
                     f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU."
                 )
+            # Legacy flags still present at this point (normalizer strips them later).
             attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
             for attn in attn_list:
                 if data.get(attn):
@@ -1658,50 +1655,46 @@ class EBFTValidationMixin:
             )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_ebft_gradient_checkpointing_reentrant(cls, data):
+    @model_validator(mode="after")
+    def check_ebft_gradient_checkpointing_reentrant(self):
         """flex_attention + non-reentrant gradient checkpointing causes CheckpointError."""
         if (
-            data.get("rl") == "ebft"
-            and data.get("ebft", {}).get("mode") == "strided"
-            and (
-                data.get("flex_attention")
-                or data.get("attn_implementation") == "flex_attention"
-            )
-            and data.get("gradient_checkpointing")
+            self.rl == "ebft"
+            and (self.ebft or {}).get("mode") == "strided"
+            and self.attn_implementation == "flex_attention"
+            and self.gradient_checkpointing
         ):
-            gc_kwargs = data.get("gradient_checkpointing_kwargs") or {}
+            gc_kwargs = self.gradient_checkpointing_kwargs or {}
             if not gc_kwargs.get("use_reentrant"):
                 LOG.warning(
                     "EBFT strided mode with flex_attention: setting `use_reentrant: true` in "
                     "gradient_checkpointing_kwargs (required for flex_attention compatibility). "
                     "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata."
                 )
-                if data.get("gradient_checkpointing_kwargs") is None:
-                    data["gradient_checkpointing_kwargs"] = {}
-                data["gradient_checkpointing_kwargs"]["use_reentrant"] = True
-        return data
+                if self.gradient_checkpointing_kwargs is None:
+                    self.gradient_checkpointing_kwargs = {}
+                self.gradient_checkpointing_kwargs["use_reentrant"] = True
+        return self
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_ebft_activation_offloading(cls, data):
+    @model_validator(mode="after")
+    def check_ebft_activation_offloading(self):
         """activation_offloading replaces gradient checkpointing with FSDP-style wrapping,
         which conflicts with flex_attention's use_reentrant requirement."""
         if (
-            data.get("rl") == "ebft"
-            and data.get("ebft", {}).get("mode") == "strided"
-            and data.get("activation_offloading") is True
-            and data.get("flex_attention")
+            self.rl == "ebft"
+            and (self.ebft or {}).get("mode") == "strided"
+            and self.activation_offloading is True
+            and self.attn_implementation == "flex_attention"
         ):
             raise ValueError(
                 "EBFT strided mode: `activation_offloading: true` is incompatible with "
-                "`flex_attention: true`. Activation offloading replaces gradient checkpointing "
-                "with FSDP-style wrapping that conflicts with flex_attention's reentrant "
-                "checkpoint requirement. Remove `activation_offloading` — the strided trainer "
-                "uses micro-batched forward passes for memory efficiency instead."
+                "`attn_implementation: flex_attention`. Activation offloading replaces "
+                "gradient checkpointing with FSDP-style wrapping that conflicts with "
+                "flex_attention's reentrant checkpoint requirement. Remove "
+                "`activation_offloading` — the strided trainer uses micro-batched forward "
+                "passes for memory efficiency instead."
             )
-        return data
+        return self
 
     @model_validator(mode="before")
     @classmethod
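
The net effect of the two EBFT guards is: reject `activation_offloading` outright, and coerce `use_reentrant: true` whenever strided EBFT runs flex_attention with gradient checkpointing. A dict-based sketch of the combined behavior (standalone; assumes `ebft` is a plain dict, as the validator's `.get` access implies):

# Combined sketch of the EBFT strided + flex_attention guards (not patch code).
def apply_ebft_guards(cfg):
    strided_flex = (
        cfg.get("rl") == "ebft"
        and (cfg.get("ebft") or {}).get("mode") == "strided"
        and cfg.get("attn_implementation") == "flex_attention"
    )
    if strided_flex and cfg.get("activation_offloading") is True:
        raise ValueError("activation_offloading is incompatible with flex_attention here")
    if strided_flex and cfg.get("gradient_checkpointing"):
        # Coerce reentrant checkpointing; non-reentrant fails on BlockMask metadata.
        kwargs = cfg.get("gradient_checkpointing_kwargs") or {}
        if not kwargs.get("use_reentrant"):
            kwargs["use_reentrant"] = True
            cfg["gradient_checkpointing_kwargs"] = kwargs
    return cfg


cfg = apply_ebft_guards({
    "rl": "ebft",
    "ebft": {"mode": "strided"},
    "attn_implementation": "flex_attention",
    "gradient_checkpointing": True,
})
assert cfg["gradient_checkpointing_kwargs"] == {"use_reentrant": True}

Like the validator, the sketch overrides an explicit `use_reentrant: false` rather than only filling in a missing key.
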