From 506e3a39074a76df223af211af8e503343ea6b3e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 14 Aug 2025 08:21:50 +0700 Subject: [PATCH] fix: fsdp_config validation being None (#3061) [skip ci] * fix: fsdp_config validation being None * fix: handling --------- Co-authored-by: salman --- src/axolotl/utils/schemas/validation.py | 26 ++++++++++++------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 0d6d05a0e..217244b01 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -370,10 +370,10 @@ class TrainingValidationMixin: "see speed improvements. Please consider setting `torch_compile: " "true` in your config." ) + fsdp_config = data.get("fsdp_config") or {} if data.get("fp8") and ( - data.get("fsdp_config", {}).get("activation_checkpointing", False) is True - or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False) - is True + fsdp_config.get("activation_checkpointing", False) is True + or fsdp_config.get("fsdp_activation_checkpointing", False) is True ): LOG.warning( "FP8 + FSDP2 + activation checkpointing may be slower than BF16 " @@ -818,13 +818,13 @@ class OptimizationValidationMixin: @model_validator(mode="before") @classmethod def check_fsdp_version_in_fsdp_config(cls, data): - if data.get("fsdp_config"): - if data.get("fsdp_config", {}).get("fsdp_version"): - LOG.warning( - "Configuring `fsdp_version` in `fsdp_config` is deprecated. " - "Please configure `fsdp_version` as a top-level field." - ) - data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config and fsdp_config.get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." + ) + data["fsdp_version"] = fsdp_config.pop("fsdp_version") return data @model_validator(mode="before") @@ -1152,10 +1152,8 @@ class ModelCompatibilityValidationMixin: @classmethod def check_gpt_oss_fsdp_loading(cls, data): if data.get("model_quantization_config", "") == "Mxfp4Config": - if ( - data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False) - is True - ): + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config.get("cpu_ram_efficient_loading", False) is True: raise ValueError( "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization." )