fix: fsdp_config validation being None (#3061) [skip ci]
* fix: fsdp_config validation being None
* fix: handling

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -370,10 +370,10 @@ class TrainingValidationMixin:
|
|||||||
"see speed improvements. Please consider setting `torch_compile: "
|
"see speed improvements. Please consider setting `torch_compile: "
|
||||||
"true` in your config."
|
"true` in your config."
|
||||||
)
|
)
|
||||||
|
fsdp_config = data.get("fsdp_config") or {}
|
||||||
if data.get("fp8") and (
|
if data.get("fp8") and (
|
||||||
data.get("fsdp_config", {}).get("activation_checkpointing", False) is True
|
fsdp_config.get("activation_checkpointing", False) is True
|
||||||
or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False)
|
or fsdp_config.get("fsdp_activation_checkpointing", False) is True
|
||||||
is True
|
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
|
"FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
|
||||||
@@ -818,13 +818,13 @@ class OptimizationValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||||
if data.get("fsdp_config"):
|
fsdp_config = data.get("fsdp_config") or {}
|
||||||
if data.get("fsdp_config", {}).get("fsdp_version"):
|
if fsdp_config and fsdp_config.get("fsdp_version"):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||||
"Please configure `fsdp_version` as a top-level field."
|
"Please configure `fsdp_version` as a top-level field."
|
||||||
)
|
)
|
||||||
data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
|
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1152,10 +1152,8 @@ class ModelCompatibilityValidationMixin:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_gpt_oss_fsdp_loading(cls, data):
|
def check_gpt_oss_fsdp_loading(cls, data):
|
||||||
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
||||||
if (
|
fsdp_config = data.get("fsdp_config") or {}
|
||||||
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
|
if fsdp_config.get("cpu_ram_efficient_loading", False) is True:
|
||||||
is True
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user