FSDP2 fix validation and add tests (#2910)

* fix validation and add tests

* remove debugging and add more tests

* remove migrate_fsdp
This commit is contained in:
Wing Lian
2025-07-14 09:25:44 -04:00
committed by GitHub
parent 80dc4c261a
commit af92151a7b
6 changed files with 283 additions and 119 deletions

View File

@@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
validate_config,
@@ -227,7 +226,6 @@ def load_cfg(
},
)
migrate_fsdp_config(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)

View File

@@ -314,16 +314,3 @@ def prepare_plugins(cfg):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
# TODO @SalmanMohammadi remove this function in 0.12
def migrate_fsdp_config(cfg):
    """Migrate a legacy FSDP config in-place to the new schema.

    Moves a nested ``fsdp_config.fsdp_version`` up to the top-level
    ``cfg.fsdp_version`` and strips the deprecated ``fsdp_`` prefix from the
    remaining keys inside ``cfg.fsdp_config``.

    Args:
        cfg: dict-like config object (supports attribute access); mutated
            in place. No-op when ``fsdp_config`` is absent or falsy.
    """
    if not cfg.get("fsdp_config"):
        return
    if "fsdp_version" in cfg.fsdp_config:
        cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
    # Snapshot the keys since we mutate the dict while renaming.
    for key in list(cfg.fsdp_config.keys()):
        if key.startswith("fsdp_"):
            # removeprefix (not str.replace) so only the leading "fsdp_"
            # is stripped, never an interior occurrence of the substring.
            cfg.fsdp_config[key.removeprefix("fsdp_")] = cfg.fsdp_config.pop(key)

View File

@@ -1143,72 +1143,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version(cls, data):
    """Log a deprecation notice when FSDP is configured but version 2 is not selected.

    Warn-only: never rejects the config. ``str()`` comparison accepts both
    the int ``2`` and the string ``"2"`` as FSDP2.
    """
    fsdp_config = data.get("fsdp_config", {})
    if fsdp_config and str(data.get("fsdp_version")) != "2":
        # Fixed: original message concatenated to "Axolotl.We recommend"
        # (missing space) and ended in a dangling sentence fragment.
        LOG.info(
            "FSDP1 will be deprecated in an upcoming release of Axolotl. "
            "We recommend that you use FSDP version 2 for better performance and compatibility. "
            "Please see https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
            "for more details on migrating your config."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
    """Reject cpu_ram_efficient_loading combined with bnb-quantized loading under FSDP2."""
    fsdp_config = data.get("fsdp_config")
    # Guard clause: only relevant when FSDP2 is actually configured.
    if not fsdp_config or data.get("fsdp_version") != 2:
        return data
    quantized = data.get("load_in_8bit") or data.get("load_in_4bit")
    if fsdp_config.get("cpu_ram_efficient_loading") and quantized:
        raise ValueError(
            "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
            "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_dpo(cls, data):
    """Reject bnb-quantized base model loading under FSDP2 for DPO-family RL trainers.

    Applies to DPO, KTO, ORPO, and IPO; the error names the actual RL type
    (the previous message said "with DPO" even when KTO/ORPO/IPO triggered it).
    """
    if data.get("fsdp_version") == 2 and data.get("rl") in [
        RLType.DPO,
        RLType.KTO,
        RLType.ORPO,
        RLType.IPO,
    ]:
        if data.get("load_in_8bit") or data.get("load_in_4bit"):
            raise ValueError(
                f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
            )
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
    """Emit a deprecation warning when `fsdp_version` is nested inside `fsdp_config`."""
    fsdp_config = data.get("fsdp_config")
    if fsdp_config and fsdp_config.get("fsdp_version"):
        LOG.warning(
            "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
            "Please configure `fsdp_version` as a top-level field."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Warn when any key in `fsdp_config` still uses the deprecated `fsdp_` prefix.

    Warn-only; keys are not rewritten here.
    """
    if fsdp_config := data.get("fsdp_config"):
        # One warning suffices; iterate keys directly instead of unused
        # (key, value) pairs, and stop at the first match.
        if any(key.startswith("fsdp_") for key in fsdp_config):
            LOG.warning_once(
                "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
                "Please omit the `fsdp_` prefix from any fields in `fsdp_config`."
            )
    return data
@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):

View File

@@ -1,6 +1,6 @@
"""Module with validation methods for config pydantic model."""
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines,too-many-boolean-expressions
import logging
@@ -748,44 +748,128 @@ class OptimizationValidationMixin:
@model_validator(mode="before")
@classmethod
def check_fsdp_version(cls, data):
    """Log a deprecation notice when FSDP is configured but version 2 is not selected.

    Warn-only: never rejects the config. ``str()`` comparison accepts both
    the int ``2`` and the string ``"2"`` as FSDP2.

    NOTE(review): this span was a garbled diff interleave — fragments of the
    removed ``check_fsdp_offload_w_8bit_optimizer`` were merged into this
    definition; reconstructed to the added validator only.
    """
    fsdp_config = data.get("fsdp_config", {})
    if fsdp_config and str(data.get("fsdp_version")) != "2":
        # Fixed: message previously concatenated to "Axolotl.We recommend"
        # (missing space) and trailed off in a sentence fragment.
        LOG.info(
            "FSDP1 will be deprecated in an upcoming release of Axolotl. "
            "We recommend that you use FSDP version 2 for better performance and compatibility. "
            "Please see https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
            "for more details on migrating your config."
        )
    return data
@model_validator(mode="after")
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
    """Reject cpu_ram_efficient_loading combined with bnb-quantized loading under FSDP2.

    Removed a stray interleaved diff line inside the raise that referenced
    an undefined ``data`` (residue of the removed before-mode validator).
    """
    # hasattr guards: these fields may be absent on partially-built models.
    fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
    fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
    load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
    load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
    if fsdp_config and fsdp_version == 2:
        if fsdp_config.get("cpu_ram_efficient_loading") and (
            load_in_8bit or load_in_4bit
        ):
            raise ValueError(
                "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
                "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
            )
    return self
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_rl(cls, data):
    """Reject bnb-quantized base model loading under FSDP2 for DPO-family RL trainers."""
    unsupported_rl_types = [
        RLType.DPO,
        RLType.KTO,
        RLType.ORPO,
        RLType.IPO,
    ]
    is_fsdp2 = data.get("fsdp_version") == 2
    is_quantized = data.get("load_in_8bit") or data.get("load_in_4bit")
    if is_fsdp2 and data.get("rl") in unsupported_rl_types and is_quantized:
        raise ValueError(
            f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
    """Migrate a nested `fsdp_config.fsdp_version` to the top-level field.

    Warns about the deprecation and pops the nested key so downstream
    validators only see the top-level ``fsdp_version``.

    Removed a stray duplicate ``def`` header (diff residue of the old
    ``check_fsdp_sharded_state_dict_w_safetensors``) and collapsed the
    repeated ``data.get("fsdp_config")`` lookups into one.
    """
    fsdp_config = data.get("fsdp_config")
    if fsdp_config and fsdp_config.get("fsdp_version"):
        LOG.warning(
            "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
            "Please configure `fsdp_version` as a top-level field."
        )
        data["fsdp_version"] = fsdp_config.pop("fsdp_version")
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Strip the deprecated `fsdp_` prefix from keys in `fsdp_config`, warning once.

    ``fsdp_version`` keeps its name so the nested-version migration can
    still find it.
    """
    if fsdp_config := data.get("fsdp_config"):
        # Iterate keys directly (values were unused) and warn once for the
        # whole config rather than once per prefixed key.
        if any(key.startswith("fsdp_") for key in fsdp_config):
            LOG.warning_once(
                "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
                "Please omit the `fsdp_` prefix from any fields in `fsdp_config`."
            )
            # removeprefix (not str.replace) so only the leading "fsdp_"
            # is stripped, never an interior occurrence.
            data["fsdp_config"] = {
                (
                    key.removeprefix("fsdp_")
                    if key.startswith("fsdp_") and key != "fsdp_version"
                    else key
                ): value
                for key, value in fsdp_config.items()
            }
    return data
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
    """Reject FSDP1 parameter offload combined with an 8-bit optimizer.

    Removed stray interleaved diff lines referencing an undefined ``data``
    inside the condition (residue of the removed before-mode validator),
    and switched the ``offload_params`` access to ``.get`` so a config
    without that key no longer raises ``KeyError``.
    """
    if (
        hasattr(self, "fsdp_config")
        and self.fsdp_config
        and self.optimizer
        and "8bit" in self.optimizer.value
        and self.fsdp_config.get("offload_params")
        and str(self.fsdp_version) != "2"
    ):
        raise ValueError(
            f"FSDP Offload not compatible with {str(self.optimizer.value)}"
        )
    return self
@model_validator(mode="after")
def check_fsdp2_w_8bit_optimizer(self):
    """Reject bitsandbytes 8-bit optimizers under FSDP2, pointing at `adamw_torch_8bit`."""
    if (
        hasattr(self, "fsdp_config")
        and self.fsdp_config
        and self.optimizer
        and "8bit" in self.optimizer.value
        and str(self.fsdp_version) == "2"
    ):
        # NOTE(review): this membership test compares the optimizer enum
        # member against plain strings — list `in` uses `==`, so this
        # assumes a str-based enum; confirm against the optimizer enum
        # definition (a set literal here would change behavior, since
        # set membership hashes the enum member, not its value).
        if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]:
            # CUDA ops errors with bnb 8bit optimizer + FSDP2
            raise ValueError(
                f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead"
            )
    return self
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
    """Reject SHARDED_STATE_DICT checkpointing combined with save_safetensors.

    Removed a stray residual ``return data`` (diff residue from the removed
    before-mode validator) that referenced an undefined name ahead of the
    correct ``return self``.
    """
    if (
        hasattr(self, "fsdp_config")
        and self.fsdp_config
        and hasattr(self, "save_safetensors")
        and self.save_safetensors
        and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
    ):
        raise ValueError(
            "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
        )
    return self
class SystemValidationMixin: