Add validation for DFT (dynamic fine-tuning)

This commit is contained in:
Wing Lian
2025-08-11 21:28:07 -04:00
parent 75ad1a9932
commit 208f8b253f

View File

@@ -434,6 +434,18 @@ class TrainingValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_ao_optim_fsdp2_offload(cls, data):
    """Refuse low-bit "ao" AdamW optimizers when FSDP parameter offload is on.

    Args:
        data: Raw configuration passed to the "before" validator; it is
            accessed through ``.get`` only, so it is presumably a dict-like
            mapping -- confirm against the caller.

    Returns:
        ``data`` unchanged when the combination is acceptable.

    Raises:
        ValueError: If ``fsdp_config.offload_params`` is truthy and
            ``optimizer`` is ``adamw_torch_8bit`` or ``adamw_torch_4bit``.

    NOTE(review): the error message mentions FSDP2, but this check does not
    look at any FSDP version setting; it fires for every truthy
    ``fsdp_config`` that enables ``offload_params``.
    """
    low_bit_optimizers = ["adamw_torch_8bit", "adamw_torch_4bit"]

    fsdp_settings = data.get("fsdp_config")
    # Only consult offload_params when an FSDP config is actually present.
    offload_enabled = fsdp_settings and fsdp_settings.get("offload_params")

    if offload_enabled and data.get("optimizer") in low_bit_optimizers:
        raise ValueError(
            "low bit ao optimizers is not supported with FSDP2 w/ offload_params."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
return data
class CELossValidationMixin:
    """Validators covering cross-entropy (CE) loss configuration options."""

    @model_validator(mode="before")
    @classmethod
    def check_dft_loss_fn(cls, data):
        """Ensure dynamic fine-tuning is only enabled with chunked cross entropy.

        Args:
            data: Raw configuration passed to the "before" validator; it is
                accessed through ``.get`` only, so it is presumably a
                dict-like mapping -- confirm against the caller.

        Returns:
            ``data`` unchanged when the settings are compatible.

        Raises:
            ValueError: If ``use_dynamic_finetuning`` is truthy while
                ``chunked_cross_entropy`` is missing or falsy.
        """
        dft_requested = data.get("use_dynamic_finetuning")
        # chunked_cross_entropy is only looked up when DFT was requested.
        if dft_requested and not data.get("chunked_cross_entropy"):
            raise ValueError(
                "`use_dynamic_finetuning` requires `chunked_cross_entropy`"
            )
        return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1464,6 +1490,7 @@ class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
TrainingValidationMixin,
CELossValidationMixin,
LoRAValidationMixin,
RLValidationMixin,
OptimizationValidationMixin,