FSDPConfig (#3170)

salman authored on 2025-10-10 14:44:25 +01:00, committed by GitHub
parent bc2ffb8204, commit 143dea4753
7 changed files with 75 additions and 18 deletions

View File

@@ -29,7 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true
wandb_project:

View File

@@ -560,13 +560,6 @@ class AxolotlTrainer(
        super().create_accelerator_and_postprocess()

        if self.is_fsdp_enabled:
            if (
                "limit_all_gathers" in self.args.fsdp_config
                and self.args.fsdp_config["limit_all_gathers"]
            ):
                self.accelerator.state.fsdp_plugin.limit_all_gathers = True

    def additional_accelerator_args(
        self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
    ) -> dict[str, Any]:

View File

@@ -24,6 +24,7 @@ from axolotl.utils.schemas.datasets import (
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
    CometConfig,
    GradioConfig,
@@ -667,8 +668,7 @@ class AxolotlInputConfig(
json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
)
# TODO @SalmanMohammadi strongly type this as its own schema
fsdp_config: dict[str, Any] | None = Field(
fsdp_config: FSDPConfig | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"}
)
fsdp_version: int | None = Field(
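
Since the field is now declared as FSDPConfig rather than a raw dict, pydantic coerces a plain mapping into the typed model during validation. A minimal sketch of that behaviour, assuming pydantic v2 and using a hypothetical stand-in for AxolotlInputConfig (which has many more fields than shown here):

from pydantic import BaseModel

from axolotl.utils.schemas.fsdp import FSDPConfig


class ConfigSketch(BaseModel):
    # Hypothetical stand-in for AxolotlInputConfig; only the field relevant here.
    fsdp_config: FSDPConfig | None = None


cfg = ConfigSketch(
    fsdp_config={
        "offload_params": False,
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    }
)
assert isinstance(cfg.fsdp_config, FSDPConfig)   # dict is coerced into the typed schema
assert cfg.fsdp_config.offload_params is False   # attribute access replaces dict-key lookups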

View File

@@ -0,0 +1,71 @@
"""
FSDP Configuration Schema
"""
from typing import Literal
from pydantic import BaseModel, Field
class FSDPConfig(BaseModel):
"""
FSDP Configuration Schema
"""
activation_checkpointing: bool | None = Field(
default=None,
description="Enable activation checkpointing to reduce memory usage during forward passes",
)
offload_params: bool | None = Field(
default=None,
description="Offload parameters to CPU to reduce GPU memory usage",
)
sync_module_states: bool | None = Field(
default=None,
description="Synchronize module states across all processes",
)
cpu_ram_efficient_loading: bool | None = Field(
default=None,
description="Enable CPU RAM efficient loading to reduce memory usage during model loading",
)
cpu_offload_pin_memory: bool | None = Field(
default=None,
description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.",
)
use_orig_params: bool | None = Field(
default=None,
description="Use original parameters instead of flattened parameters",
)
state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Type of state dict to use for saving/loading checkpoints",
)
final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Final state dict type to use after training completion",
)
auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = (
Field(
default=None,
description="Policy for automatically wrapping modules with FSDP",
)
)
transformer_layer_cls_to_wrap: str | None = Field(
default=None,
description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')",
)
reshard_after_forward: bool | None = Field(
default=None,
description="Reshard parameters after forward pass to save memory",
)
mixed_precision_policy: str | None = Field(
default=None,
description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')",
)
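
A rough sketch of how the new schema validates on its own (assumption: pydantic's default handling of extra keys, since the class above sets no model_config), which also explains why arbitrary keys such as "regular_param" no longer survive validation in the test changes further down:

from axolotl.utils.schemas.fsdp import FSDPConfig

# Sketch only: exercising the schema directly, outside Axolotl's normal config loading.
fsdp = FSDPConfig.model_validate(
    {
        "offload_params": False,
        "cpu_ram_efficient_loading": True,
        "transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
        "regular_param": "value",  # unknown key; dropped under pydantic's default "ignore" behaviour
    }
)
assert fsdp.offload_params is False           # attribute access, as in the validator change below
assert fsdp.mixed_precision_policy is None    # unset fields default to None
assert not hasattr(fsdp, "regular_param")     # unrecognised keys are not kept as attributes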

View File

@@ -881,7 +881,7 @@ class OptimizationValidationMixin:
            and self.fsdp_config
            and self.optimizer
            and "8bit" in self.optimizer.value
            and self.fsdp_config["offload_params"]
            and self.fsdp_config.offload_params
            and str(self.fsdp_version) != "2"
        ):
            raise ValueError(

View File

@@ -353,7 +353,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
@@ -431,7 +430,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
@@ -595,7 +593,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,

View File

@@ -111,7 +111,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"regular_param": "value",
}
}
)
@@ -124,7 +123,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
        )
        self.assertEqual(cfg_with_version.fsdp_config.offload_params, False)
        self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True)
        self.assertEqual(cfg_with_version.fsdp_config.regular_param, "value")
        self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config)
        self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config)
@@ -137,7 +135,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
"fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
"fsdp_offload_params": True,
"regular_param": "value",
}
}
)
@@ -149,7 +146,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
            cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP"
        )
        self.assertEqual(cfg_without_version.fsdp_config.offload_params, True)
        self.assertEqual(cfg_without_version.fsdp_config.regular_param, "value")
        self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config)
        self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config)