fix: fsdp_config validation being None (#3061) [skip ci]

* fix: fsdp_config validation being None

* fix: handling

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
NanoCode012
2025-08-14 08:21:50 +07:00
committed by GitHub
parent 09145de8fa
commit 506e3a3907

View File

@@ -370,10 +370,10 @@ class TrainingValidationMixin:
"see speed improvements. Please consider setting `torch_compile: " "see speed improvements. Please consider setting `torch_compile: "
"true` in your config." "true` in your config."
) )
fsdp_config = data.get("fsdp_config") or {}
if data.get("fp8") and ( if data.get("fp8") and (
data.get("fsdp_config", {}).get("activation_checkpointing", False) is True fsdp_config.get("activation_checkpointing", False) is True
or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False) or fsdp_config.get("fsdp_activation_checkpointing", False) is True
is True
): ):
LOG.warning( LOG.warning(
"FP8 + FSDP2 + activation checkpointing may be slower than BF16 " "FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
@@ -818,13 +818,13 @@ class OptimizationValidationMixin:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_fsdp_version_in_fsdp_config(cls, data): def check_fsdp_version_in_fsdp_config(cls, data):
if data.get("fsdp_config"): fsdp_config = data.get("fsdp_config") or {}
if data.get("fsdp_config", {}).get("fsdp_version"): if fsdp_config and fsdp_config.get("fsdp_version"):
LOG.warning( LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. " "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field." "Please configure `fsdp_version` as a top-level field."
) )
data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1152,10 +1152,8 @@ class ModelCompatibilityValidationMixin:
@classmethod @classmethod
def check_gpt_oss_fsdp_loading(cls, data): def check_gpt_oss_fsdp_loading(cls, data):
if data.get("model_quantization_config", "") == "Mxfp4Config": if data.get("model_quantization_config", "") == "Mxfp4Config":
if ( fsdp_config = data.get("fsdp_config") or {}
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False) if fsdp_config.get("cpu_ram_efficient_loading", False) is True:
is True
):
raise ValueError( raise ValueError(
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization." "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
) )