Add validation for dynamic fine-tuning (DFT)
This commit is contained in:
@@ -434,6 +434,18 @@ class TrainingValidationMixin:
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_ao_optim_fsdp2_offload(cls, data):
    """Reject low-bit ao AdamW optimizers when FSDP parameter offload is on.

    Runs before model construction on the raw config mapping.

    Args:
        data: Raw configuration mapping being validated.

    Returns:
        The same ``data`` object, unmodified, when the combination is allowed.

    Raises:
        ValueError: If ``fsdp_config.offload_params`` is truthy and
            ``optimizer`` is ``adamw_torch_8bit`` or ``adamw_torch_4bit``.
    """
    fsdp_settings = data.get("fsdp_config")
    # Nothing to check unless an FSDP config exists and requests offload.
    # NOTE(review): only `offload_params` is inspected here; the FSDP version
    # is not checked even though the error message mentions FSDP2.
    if not fsdp_settings or not fsdp_settings.get("offload_params"):
        return data

    low_bit_optimizers = ("adamw_torch_8bit", "adamw_torch_4bit")
    if data.get("optimizer") in low_bit_optimizers:
        raise ValueError(
            "low bit ao optimizers is not supported with FSDP2 w/ offload_params."
        )
    return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_use_reentrant_mismatch(cls, data):
|
||||
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
|
||||
return data
|
||||
|
||||
|
||||
class CELossValidationMixin:
    """Validators for cross-entropy (CE) loss configuration options."""

    @model_validator(mode="before")
    @classmethod
    def check_dft_loss_fn(cls, data):
        """Require chunked cross entropy whenever dynamic fine-tuning is on.

        Runs before model construction on the raw config mapping.

        Args:
            data: Raw configuration mapping being validated.

        Returns:
            The same ``data`` object, unmodified, when the check passes.

        Raises:
            ValueError: If ``use_dynamic_finetuning`` is truthy while
                ``chunked_cross_entropy`` is missing or falsy.
        """
        # The second lookup only happens when DFT is actually enabled.
        if data.get("use_dynamic_finetuning") and not data.get(
            "chunked_cross_entropy"
        ):
            raise ValueError(
                "`use_dynamic_finetuning` requires `chunked_cross_entropy`"
            )
        return data
|
||||
|
||||
|
||||
class LoRAValidationMixin:
|
||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||
|
||||
@@ -1464,6 +1490,7 @@ class ValidationMixin(
|
||||
DatasetValidationMixin,
|
||||
AttentionValidationMixin,
|
||||
TrainingValidationMixin,
|
||||
CELossValidationMixin,
|
||||
LoRAValidationMixin,
|
||||
RLValidationMixin,
|
||||
OptimizationValidationMixin,
|
||||
|
||||
Reference in New Issue
Block a user