FSDPConfig (#3170)
@@ -29,7 +29,7 @@ flex_attention: true
 flex_attn_compile_kwargs:
   dynamic: false
   mode: max-autotune-no-cudagraphs

 save_strategy: no
 torch_compile: true

 wandb_project:
@@ -560,13 +560,6 @@ class AxolotlTrainer(

         super().create_accelerator_and_postprocess()

-        if self.is_fsdp_enabled:
-            if (
-                "limit_all_gathers" in self.args.fsdp_config
-                and self.args.fsdp_config["limit_all_gathers"]
-            ):
-                self.accelerator.state.fsdp_plugin.limit_all_gathers = True
-
     def additional_accelerator_args(
         self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
     ) -> dict[str, Any]:
@@ -24,6 +24,7 @@ from axolotl.utils.schemas.datasets import (
 )
 from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
 from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
+from axolotl.utils.schemas.fsdp import FSDPConfig
 from axolotl.utils.schemas.integrations import (
     CometConfig,
     GradioConfig,
@@ -667,8 +668,7 @@ class AxolotlInputConfig(
         json_schema_extra={"description": "FSDP configuration"},
         deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
     )
-    # TODO @SalmanMohammadi strongly type this as its own schema
-    fsdp_config: dict[str, Any] | None = Field(
+    fsdp_config: FSDPConfig | None = Field(
         default=None, json_schema_extra={"description": "FSDP configuration options"}
     )
     fsdp_version: int | None = Field(
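
With the field now annotated as `FSDPConfig | None`, pydantic coerces the raw dict parsed from YAML into the typed model during validation. A minimal sketch of that coercion, assuming pydantic v2 semantics; `Demo` is a hypothetical stand-in for `AxolotlInputConfig`, which needs many unrelated fields to construct:

# Sketch: pydantic v2 coerces a plain dict into the nested typed model.
from pydantic import BaseModel

from axolotl.utils.schemas.fsdp import FSDPConfig


class Demo(BaseModel):
    # Hypothetical stand-in for AxolotlInputConfig's retyped field.
    fsdp_config: FSDPConfig | None = None


demo = Demo(fsdp_config={"offload_params": True})
assert isinstance(demo.fsdp_config, FSDPConfig)
assert demo.fsdp_config.offload_params is True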

src/axolotl/utils/schemas/fsdp.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+"""
+FSDP Configuration Schema
+"""
+
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class FSDPConfig(BaseModel):
+    """
+    FSDP Configuration Schema
+    """
+
+    activation_checkpointing: bool | None = Field(
+        default=None,
+        description="Enable activation checkpointing to reduce memory usage during forward passes",
+    )
+    offload_params: bool | None = Field(
+        default=None,
+        description="Offload parameters to CPU to reduce GPU memory usage",
+    )
+    sync_module_states: bool | None = Field(
+        default=None,
+        description="Synchronize module states across all processes",
+    )
+    cpu_ram_efficient_loading: bool | None = Field(
+        default=None,
+        description="Enable CPU RAM efficient loading to reduce memory usage during model loading",
+    )
+    cpu_offload_pin_memory: bool | None = Field(
+        default=None,
+        description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.",
+    )
+    use_orig_params: bool | None = Field(
+        default=None,
+        description="Use original parameters instead of flattened parameters",
+    )
+
+    state_dict_type: (
+        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
+    ) = Field(
+        default=None,
+        description="Type of state dict to use for saving/loading checkpoints",
+    )
+    final_state_dict_type: (
+        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
+    ) = Field(
+        default=None,
+        description="Final state dict type to use after training completion",
+    )
+
+    auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = (
+        Field(
+            default=None,
+            description="Policy for automatically wrapping modules with FSDP",
+        )
+    )
+    transformer_layer_cls_to_wrap: str | None = Field(
+        default=None,
+        description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')",
+    )
+
+    reshard_after_forward: bool | None = Field(
+        default=None,
+        description="Reshard parameters after forward pass to save memory",
+    )
+    mixed_precision_policy: str | None = Field(
+        default=None,
+        description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')",
+    )
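
Every field in the schema above is optional and defaults to None, so downstream code gets attribute access plus fail-fast validation on the Literal-typed fields. A minimal sketch of that behavior, assuming pydantic v2; the key values here are illustrative:

# Unset fields read as None instead of raising KeyError.
from pydantic import ValidationError

from axolotl.utils.schemas.fsdp import FSDPConfig

cfg = FSDPConfig(
    offload_params=False,
    auto_wrap_policy="TRANSFORMER_BASED_WRAP",
    state_dict_type="SHARDED_STATE_DICT",
)
assert cfg.offload_params is False
assert cfg.reshard_after_forward is None  # never set, so None

# Literal-typed fields reject typos at config-parse time.
try:
    FSDPConfig(auto_wrap_policy="NOT_A_POLICY")
except ValidationError as err:
    print(err.error_count(), "validation error(s)")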
@@ -881,7 +881,7 @@ class OptimizationValidationMixin:
             and self.fsdp_config
             and self.optimizer
             and "8bit" in self.optimizer.value
-            and self.fsdp_config["offload_params"]
+            and self.fsdp_config.offload_params
             and str(self.fsdp_version) != "2"
         ):
             raise ValueError(
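
Note the access-pattern change in this guard: on the typed model an unset `offload_params` is None, which is falsy, so the 8-bit-optimizer check only fires when offloading is explicitly enabled. A tiny illustration; the guard shape mirrors the hunk, and the error message is made up:

# Illustration only: unset optional fields are falsy.
from axolotl.utils.schemas.fsdp import FSDPConfig

fsdp_config = FSDPConfig()  # offload_params left unset -> None
if fsdp_config.offload_params:
    raise ValueError("hypothetical: 8-bit optimizer incompatible with offload")
# Not raised: None is falsy, so configs that never set the key pass the guard.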
@@ -353,7 +353,6 @@ class TestMultiGPULlama:
                 "auto_wrap",
             ],
             "fsdp_config": {
-                "fsdp_limit_all_gathers": True,
                 "fsdp_offload_params": False,
                 "fsdp_sync_module_states": True,
                 "fsdp_use_orig_params": False,
@@ -431,7 +430,6 @@ class TestMultiGPULlama:
                 "auto_wrap",
             ],
             "fsdp_config": {
-                "fsdp_limit_all_gathers": True,
                 "fsdp_offload_params": False,
                 "fsdp_sync_module_states": True,
                 "fsdp_use_orig_params": False,
@@ -595,7 +593,6 @@ class TestMultiGPULlama:
                 "auto_wrap",
             ],
             "fsdp_config": {
-                "fsdp_limit_all_gathers": True,
                 "fsdp_offload_params": False,
                 "fsdp_sync_module_states": True,
                 "fsdp_use_orig_params": False,
@@ -111,7 +111,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                     "fsdp_offload_params": False,
                     "fsdp_cpu_ram_efficient_loading": True,
-                    "regular_param": "value",
                 }
             }
         )
@@ -124,7 +123,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
         )
         self.assertEqual(cfg_with_version.fsdp_config.offload_params, False)
         self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True)
-        self.assertEqual(cfg_with_version.fsdp_config.regular_param, "value")

         self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config)
         self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config)
@@ -137,7 +135,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
                 "fsdp_config": {
                     "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
                     "fsdp_offload_params": True,
-                    "regular_param": "value",
                 }
             }
         )
@@ -149,7 +146,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
             cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP"
         )
         self.assertEqual(cfg_without_version.fsdp_config.offload_params, True)
-        self.assertEqual(cfg_without_version.fsdp_config.regular_param, "value")

         self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config)
         self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config)
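
The removed `regular_param` lines are consistent with the new typing: `FSDPConfig` declares no such field, and pydantic's default model config ignores unknown keys, so arbitrary extras no longer survive normalization. A quick sketch of that default, assuming pydantic v2's extra="ignore" behavior:

# Unknown keys are silently dropped by the default extra="ignore" policy.
from axolotl.utils.schemas.fsdp import FSDPConfig

cfg = FSDPConfig(offload_params=True, regular_param="value")
assert not hasattr(cfg, "regular_param")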