fix: fsdp_config validation being None (#3061) [skip ci]
* fix: fsdp_config validation being None * fix: handling --------- Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -370,10 +370,10 @@ class TrainingValidationMixin:
|
||||
"see speed improvements. Please consider setting `torch_compile: "
|
||||
"true` in your config."
|
||||
)
|
||||
fsdp_config = data.get("fsdp_config") or {}
|
||||
if data.get("fp8") and (
|
||||
data.get("fsdp_config", {}).get("activation_checkpointing", False) is True
|
||||
or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False)
|
||||
is True
|
||||
fsdp_config.get("activation_checkpointing", False) is True
|
||||
or fsdp_config.get("fsdp_activation_checkpointing", False) is True
|
||||
):
|
||||
LOG.warning(
|
||||
"FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
|
||||
@@ -818,13 +818,13 @@ class OptimizationValidationMixin:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||
if data.get("fsdp_config"):
|
||||
if data.get("fsdp_config", {}).get("fsdp_version"):
|
||||
LOG.warning(
|
||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||
"Please configure `fsdp_version` as a top-level field."
|
||||
)
|
||||
data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
|
||||
fsdp_config = data.get("fsdp_config") or {}
|
||||
if fsdp_config and fsdp_config.get("fsdp_version"):
|
||||
LOG.warning(
|
||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||
"Please configure `fsdp_version` as a top-level field."
|
||||
)
|
||||
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1152,10 +1152,8 @@ class ModelCompatibilityValidationMixin:
|
||||
@classmethod
|
||||
def check_gpt_oss_fsdp_loading(cls, data):
|
||||
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
||||
if (
|
||||
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
|
||||
is True
|
||||
):
|
||||
fsdp_config = data.get("fsdp_config") or {}
|
||||
if fsdp_config.get("cpu_ram_efficient_loading", False) is True:
|
||||
raise ValueError(
|
||||
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user