From 143dea4753fe4a9ff5d9ef0f303e41a32091e355 Mon Sep 17 00:00:00 2001 From: salman Date: Fri, 10 Oct 2025 14:44:25 +0100 Subject: [PATCH] `FSDPConfig` (#3170) --- examples/llama-3/3b-fp8-fsdp2.yaml | 2 +- src/axolotl/core/trainers/base.py | 7 --- src/axolotl/utils/schemas/config.py | 4 +- src/axolotl/utils/schemas/fsdp.py | 71 +++++++++++++++++++++++++ src/axolotl/utils/schemas/validation.py | 2 +- tests/e2e/multigpu/test_llama.py | 3 -- tests/test_normalize_config.py | 4 -- 7 files changed, 75 insertions(+), 18 deletions(-) create mode 100644 src/axolotl/utils/schemas/fsdp.py diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml index bea698c0e..b7de7ca52 100644 --- a/examples/llama-3/3b-fp8-fsdp2.yaml +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -29,7 +29,7 @@ flex_attention: true flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs - +save_strategy: no torch_compile: true wandb_project: diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 627f8e3f8..11dfecb98 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -560,13 +560,6 @@ class AxolotlTrainer( super().create_accelerator_and_postprocess() - if self.is_fsdp_enabled: - if ( - "limit_all_gathers" in self.args.fsdp_config - and self.args.fsdp_config["limit_all_gathers"] - ): - self.accelerator.state.fsdp_plugin.limit_all_gathers = True - def additional_accelerator_args( self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs ) -> dict[str, Any]: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 0177b19f6..7cf8c3b4a 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -24,6 +24,7 @@ from axolotl.utils.schemas.datasets import ( ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType +from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( CometConfig, GradioConfig, @@ -667,8 +668,7 @@ class AxolotlInputConfig( json_schema_extra={"description": "FSDP configuration"}, deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", ) - # TODO @SalmanMohammadi strongly type this as its own schema - fsdp_config: dict[str, Any] | None = Field( + fsdp_config: FSDPConfig | None = Field( default=None, json_schema_extra={"description": "FSDP configuration options"} ) fsdp_version: int | None = Field( diff --git a/src/axolotl/utils/schemas/fsdp.py b/src/axolotl/utils/schemas/fsdp.py new file mode 100644 index 000000000..f34f40e8e --- /dev/null +++ b/src/axolotl/utils/schemas/fsdp.py @@ -0,0 +1,71 @@ +""" +FSDP Configuration Schema +""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class FSDPConfig(BaseModel): + """ + FSDP Configuration Schema + """ + + activation_checkpointing: bool | None = Field( + default=None, + description="Enable activation checkpointing to reduce memory usage during forward passes", + ) + offload_params: bool | None = Field( + default=None, + description="Offload parameters to CPU to reduce GPU memory usage", + ) + sync_module_states: bool | None = Field( + default=None, + description="Synchronize module states across all processes", + ) + cpu_ram_efficient_loading: bool | None = Field( + default=None, + description="Enable CPU RAM efficient loading to reduce memory usage during model loading", + ) + cpu_offload_pin_memory: bool | None = Field( + default=None, + description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.", + ) + use_orig_params: bool | None = Field( + default=None, + description="Use original parameters instead of flattened parameters", + ) + + state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field( + default=None, + description="Type of state dict to use for saving/loading checkpoints", + ) + final_state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field( + default=None, + description="Final state dict type to use after training completion", + ) + + auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = ( + Field( + default=None, + description="Policy for automatically wrapping modules with FSDP", + ) + ) + transformer_layer_cls_to_wrap: str | None = Field( + default=None, + description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')", + ) + + reshard_after_forward: bool | None = Field( + default=None, + description="Reshard parameters after forward pass to save memory", + ) + mixed_precision_policy: str | None = Field( + default=None, + description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')", + ) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 4abe45e64..368976831 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -881,7 +881,7 @@ class OptimizationValidationMixin: and self.fsdp_config and self.optimizer and "8bit" in self.optimizer.value - and self.fsdp_config["offload_params"] + and self.fsdp_config.offload_params and str(self.fsdp_version) != "2" ): raise ValueError( diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index b836291e5..ffdbad942 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -353,7 +353,6 @@ class TestMultiGPULlama: "auto_wrap", ], "fsdp_config": { - "fsdp_limit_all_gathers": True, "fsdp_offload_params": False, "fsdp_sync_module_states": True, "fsdp_use_orig_params": False, @@ -431,7 +430,6 @@ class TestMultiGPULlama: "auto_wrap", ], "fsdp_config": { - "fsdp_limit_all_gathers": True, "fsdp_offload_params": False, "fsdp_sync_module_states": True, "fsdp_use_orig_params": False, @@ -595,7 +593,6 @@ class TestMultiGPULlama: "auto_wrap", ], "fsdp_config": { - "fsdp_limit_all_gathers": True, "fsdp_offload_params": False, "fsdp_sync_module_states": True, "fsdp_use_orig_params": False, diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 658e06fcb..f0d3a2d72 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -111,7 +111,6 @@ class NormalizeConfigTestCase(unittest.TestCase): "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_offload_params": False, "fsdp_cpu_ram_efficient_loading": True, - "regular_param": "value", } } ) @@ -124,7 +123,6 @@ class NormalizeConfigTestCase(unittest.TestCase): ) self.assertEqual(cfg_with_version.fsdp_config.offload_params, False) self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True) - self.assertEqual(cfg_with_version.fsdp_config.regular_param, "value") self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config) self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config) @@ -137,7 +135,6 @@ class NormalizeConfigTestCase(unittest.TestCase): "fsdp_config": { "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP", "fsdp_offload_params": True, - "regular_param": "value", } } ) @@ -149,7 +146,6 @@ class NormalizeConfigTestCase(unittest.TestCase): cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP" ) self.assertEqual(cfg_without_version.fsdp_config.offload_params, True) - self.assertEqual(cfg_without_version.fsdp_config.regular_param, "value") self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config) self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config)