From 208f8b253f76b035659554639ad7b55685528aaf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 11 Aug 2025 21:28:07 -0400 Subject: [PATCH] add validation for DFT --- src/axolotl/utils/schemas/validation.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index bf054d353..5768f311f 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -434,6 +434,18 @@ class TrainingValidationMixin: return data + @model_validator(mode="before") + @classmethod + def check_ao_optim_fsdp2_offload(cls, data): + if data.get("fsdp_config") and data.get("fsdp_config", {}).get( + "offload_params" + ): + if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]: + raise ValueError( + "low bit ao optimizers is not supported with FSDP2 w/ offload_params." + ) + return data + @model_validator(mode="before") @classmethod def check_use_reentrant_mismatch(cls, data): @@ -557,6 +569,20 @@ class TrainingValidationMixin: return data +class CELossValidationMixin: + """Validation methods related to CE loss configuration.""" + + @model_validator(mode="before") + @classmethod + def check_dft_loss_fn(cls, data): + if data.get("use_dynamic_finetuning"): + if not data.get("chunked_cross_entropy"): + raise ValueError( + "`use_dynamic_finetuning` requires `chunked_cross_entropy`" + ) + return data + + class LoRAValidationMixin: """Validation methods related to LoRA/QLoRA configuration.""" @@ -1464,6 +1490,7 @@ class ValidationMixin( DatasetValidationMixin, AttentionValidationMixin, TrainingValidationMixin, + CELossValidationMixin, LoRAValidationMixin, RLValidationMixin, OptimizationValidationMixin,