FSDP2 fix validation and add tests (#2910)
* fix validation and add tests * remove debugging and add more tests * remove migrate_fsdp
This commit is contained in:
@@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
migrate_fsdp_config,
|
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
validate_config,
|
validate_config,
|
||||||
@@ -227,7 +226,6 @@ def load_cfg(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
migrate_fsdp_config(cfg)
|
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
prepare_opinionated_env(cfg)
|
prepare_opinionated_env(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
@@ -314,16 +314,3 @@ def prepare_plugins(cfg):
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
for plugin_name in cfg["plugins"]:
|
for plugin_name in cfg["plugins"]:
|
||||||
plugin_manager.register(plugin_name)
|
plugin_manager.register(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
# TODO @SalmanMohammadi remove this function in 0.12
|
|
||||||
def migrate_fsdp_config(cfg):
|
|
||||||
if cfg.get("fsdp_config"):
|
|
||||||
fsdp_config_keys = cfg.fsdp_config.keys()
|
|
||||||
if "fsdp_version" in fsdp_config_keys:
|
|
||||||
cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
|
|
||||||
|
|
||||||
for key in list(fsdp_config_keys):
|
|
||||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
|
||||||
cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key]
|
|
||||||
del cfg.fsdp_config[key]
|
|
||||||
|
|||||||
@@ -1143,72 +1143,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_version(cls, data):
|
|
||||||
fsdp_config = data.get("fsdp_config", {})
|
|
||||||
if fsdp_config and str(data.get("fsdp_version")) != "2":
|
|
||||||
LOG.info(
|
|
||||||
"FSDP1 will be deprecated in an upcoming release of Axolotl."
|
|
||||||
"We recommend that you use FSDP version 2 for better performance and compatibility. "
|
|
||||||
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
|
|
||||||
"For more details on migrating your config. "
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
|
|
||||||
fsdp_config = data.get("fsdp_config")
|
|
||||||
if fsdp_config and data.get("fsdp_version") == 2:
|
|
||||||
if fsdp_config.get("cpu_ram_efficient_loading") and (
|
|
||||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
|
|
||||||
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp2_base_model_quant_dpo(cls, data):
|
|
||||||
if data.get("fsdp_version") == 2 and data.get("rl") in [
|
|
||||||
RLType.DPO,
|
|
||||||
RLType.KTO,
|
|
||||||
RLType.ORPO,
|
|
||||||
RLType.IPO,
|
|
||||||
]:
|
|
||||||
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
|
||||||
raise ValueError(
|
|
||||||
"FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1."
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
|
||||||
if fsdp_config := data.get("fsdp_config"):
|
|
||||||
if fsdp_config.get("fsdp_version"):
|
|
||||||
LOG.warning(
|
|
||||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
|
||||||
"Please configure `fsdp_version` as a top-level field."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
|
||||||
if fsdp_config := data.get("fsdp_config"):
|
|
||||||
for key, _ in fsdp_config.items():
|
|
||||||
if key.startswith("fsdp_"):
|
|
||||||
LOG.warning_once(
|
|
||||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
|
||||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_dataloader_opts(cls, data):
|
def default_dataloader_opts(cls, data):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Module with validation methods for config pydantic model."""
|
"""Module with validation methods for config pydantic model."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines,too-many-boolean-expressions
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -748,44 +748,128 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
def check_fsdp_version(cls, data):
|
||||||
if (
|
fsdp_config = data.get("fsdp_config", {})
|
||||||
data.get("fsdp")
|
if fsdp_config and str(data.get("fsdp_version")) != "2":
|
||||||
and "8bit" in data.get("optimizer", "")
|
LOG.info(
|
||||||
and data.get("fsdp_config")
|
"FSDP1 will be deprecated in an upcoming release of Axolotl."
|
||||||
and data["fsdp_config"].get("fsdp_offload_params")
|
"We recommend that you use FSDP version 2 for better performance and compatibility. "
|
||||||
and str(data["fsdp_config"].get("fsdp_version")) != "2"
|
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
|
||||||
):
|
"For more details on migrating your config. "
|
||||||
raise ValueError(
|
|
||||||
f"FSDP Offload not compatible with {data.get('optimizer')}"
|
|
||||||
)
|
)
|
||||||
if (
|
return data
|
||||||
data.get("fsdp")
|
|
||||||
and "8bit" in data.get("optimizer", "")
|
@model_validator(mode="after")
|
||||||
and data.get("fsdp_config")
|
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
|
||||||
and str(data["fsdp_config"].get("fsdp_version")) == "2"
|
fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
|
||||||
):
|
fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
|
||||||
if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]:
|
load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
|
||||||
# CUDA ops errors with bnb 8bit optimizer + FSDP2
|
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
|
||||||
|
if fsdp_config and fsdp_version == 2:
|
||||||
|
if fsdp_config.get("cpu_ram_efficient_loading") and (
|
||||||
|
load_in_8bit or load_in_4bit
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead"
|
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
|
||||||
|
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp2_base_model_quant_rl(cls, data):
|
||||||
|
if data.get("fsdp_version") == 2 and data.get("rl") in [
|
||||||
|
RLType.DPO,
|
||||||
|
RLType.KTO,
|
||||||
|
RLType.ORPO,
|
||||||
|
RLType.IPO,
|
||||||
|
]:
|
||||||
|
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
||||||
|
raise ValueError(
|
||||||
|
f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
|
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||||
|
if data.get("fsdp_config"):
|
||||||
|
if data.get("fsdp_config", {}).get("fsdp_version"):
|
||||||
|
LOG.warning(
|
||||||
|
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||||
|
"Please configure `fsdp_version` as a top-level field."
|
||||||
|
)
|
||||||
|
data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||||
|
if fsdp_config := data.get("fsdp_config"):
|
||||||
|
should_fix = False
|
||||||
|
for key, _ in fsdp_config.items():
|
||||||
|
if key.startswith("fsdp_"):
|
||||||
|
should_fix = True
|
||||||
|
LOG.warning_once(
|
||||||
|
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||||
|
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||||
|
)
|
||||||
|
if should_fix:
|
||||||
|
update_fsdp_config = {}
|
||||||
|
for key, value in fsdp_config.items():
|
||||||
|
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||||
|
update_fsdp_config[key.replace("fsdp_", "")] = value
|
||||||
|
else:
|
||||||
|
update_fsdp_config[key] = value
|
||||||
|
data["fsdp_config"] = update_fsdp_config
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_fsdp_offload_w_8bit_optimizer(self):
|
||||||
if (
|
if (
|
||||||
data.get("fsdp_config")
|
hasattr(self, "fsdp_config")
|
||||||
and data.get("save_safetensors")
|
and self.fsdp_config
|
||||||
and data.get("fsdp_config")
|
and self.optimizer
|
||||||
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
|
and "8bit" in self.optimizer.value
|
||||||
|
and self.fsdp_config["offload_params"]
|
||||||
|
and str(self.fsdp_version) != "2"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"FSDP Offload not compatible with {str(self.optimizer.value)}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_fsdp2_w_8bit_optimizer(self):
|
||||||
|
if (
|
||||||
|
hasattr(self, "fsdp_config")
|
||||||
|
and self.fsdp_config
|
||||||
|
and self.optimizer
|
||||||
|
and "8bit" in self.optimizer.value
|
||||||
|
and str(self.fsdp_version) == "2"
|
||||||
|
):
|
||||||
|
if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]:
|
||||||
|
# CUDA ops errors with bnb 8bit optimizer + FSDP2
|
||||||
|
raise ValueError(
|
||||||
|
f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_fsdp_sharded_state_dict_w_safetensors(self):
|
||||||
|
if (
|
||||||
|
hasattr(self, "fsdp_config")
|
||||||
|
and self.fsdp_config
|
||||||
|
and hasattr(self, "save_safetensors")
|
||||||
|
and self.save_safetensors
|
||||||
|
and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
|
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
|
||||||
)
|
)
|
||||||
return data
|
return self
|
||||||
|
|
||||||
|
|
||||||
class SystemValidationMixin:
|
class SystemValidationMixin:
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import unittest
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
migrate_fsdp_config,
|
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
|
validate_config,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -27,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"learning_rate": 0.0001,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,7 +104,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_migrate_fsdp_config(self):
|
def test_migrate_fsdp_config(self):
|
||||||
"""Test basic FSDP config migration with and without fsdp_version"""
|
"""Test basic FSDP config migration with and without fsdp_version"""
|
||||||
cfg_with_version = DictDefault(
|
cfg_with_version = self._get_base_cfg() | DictDefault(
|
||||||
{
|
{
|
||||||
"fsdp_config": {
|
"fsdp_config": {
|
||||||
"fsdp_version": 2,
|
"fsdp_version": 2,
|
||||||
@@ -109,7 +116,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
migrate_fsdp_config(cfg_with_version)
|
cfg_with_version = validate_config(cfg_with_version)
|
||||||
|
|
||||||
self.assertEqual(cfg_with_version.fsdp_version, 2)
|
self.assertEqual(cfg_with_version.fsdp_version, 2)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -125,7 +132,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
|
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
|
||||||
self.assertNotIn("version", cfg_with_version.fsdp_config)
|
self.assertNotIn("version", cfg_with_version.fsdp_config)
|
||||||
|
|
||||||
cfg_without_version = DictDefault(
|
cfg_without_version = self._get_base_cfg() | DictDefault(
|
||||||
{
|
{
|
||||||
"fsdp_config": {
|
"fsdp_config": {
|
||||||
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
|
||||||
@@ -135,7 +142,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
migrate_fsdp_config(cfg_without_version)
|
cfg_without_version = validate_config(cfg_without_version)
|
||||||
|
|
||||||
self.assertNotIn("fsdp_version", cfg_without_version)
|
self.assertNotIn("fsdp_version", cfg_without_version)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@@ -149,26 +156,25 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_migrate_fsdp_config_no_fsdp_config(self):
|
def test_migrate_fsdp_config_no_fsdp_config(self):
|
||||||
"""Test that function doesn't crash when no fsdp_config is present"""
|
"""Test that function doesn't crash when no fsdp_config is present"""
|
||||||
cfg = DictDefault({"some_other_config": "value"})
|
cfg = self._get_base_cfg()
|
||||||
|
|
||||||
migrate_fsdp_config(cfg)
|
cfg = validate_config(cfg)
|
||||||
|
|
||||||
self.assertNotIn("fsdp_config", cfg)
|
self.assertNotIn("fsdp_config", cfg)
|
||||||
self.assertNotIn("fsdp_version", cfg)
|
self.assertNotIn("fsdp_version", cfg)
|
||||||
self.assertEqual(cfg.some_other_config, "value")
|
|
||||||
|
|
||||||
def test_migrate_fsdp_config_empty_fsdp_config(self):
|
def test_migrate_fsdp_config_empty_fsdp_config(self):
|
||||||
"""Test migration with empty fsdp_config"""
|
"""Test migration with empty fsdp_config"""
|
||||||
cfg = DictDefault({"fsdp_config": {}})
|
cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}})
|
||||||
|
|
||||||
migrate_fsdp_config(cfg)
|
cfg = validate_config(cfg)
|
||||||
|
|
||||||
self.assertNotIn("fsdp_version", cfg)
|
self.assertNotIn("fsdp_version", cfg)
|
||||||
self.assertEqual(cfg.fsdp_config, {})
|
self.assertEqual(cfg.fsdp_config, {})
|
||||||
|
|
||||||
def test_migrate_fsdp_config_mixed_keys(self):
|
def test_migrate_fsdp_config_mixed_keys(self):
|
||||||
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
|
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
|
||||||
cfg = DictDefault(
|
cfg = self._get_base_cfg() | DictDefault(
|
||||||
{
|
{
|
||||||
"fsdp_config": {
|
"fsdp_config": {
|
||||||
"fsdp_version": 1,
|
"fsdp_version": 1,
|
||||||
@@ -180,7 +186,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
migrate_fsdp_config(cfg)
|
cfg = validate_config(cfg)
|
||||||
|
|
||||||
self.assertEqual(cfg.fsdp_version, 1)
|
self.assertEqual(cfg.fsdp_version, 1)
|
||||||
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
|
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
|
||||||
|
|||||||
155
tests/utils/schemas/validation/test_fsdp.py
Normal file
155
tests/utils/schemas/validation/test_fsdp.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""
|
||||||
|
tests for pydantic fsdp validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=too-many-boolean-expressions
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="fsdp_base_cfg")
|
||||||
|
def fixture_fsdp_base_cfg():
|
||||||
|
return DictDefault(
|
||||||
|
base_model="gpt2",
|
||||||
|
learning_rate=1e-3,
|
||||||
|
datasets=[
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
micro_batch_size=1,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFSDPValidation:
|
||||||
|
"""
|
||||||
|
test class for pydantic fsdp validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_fsdp_version_in_fsdp_config(self, fsdp_base_cfg):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_config={
|
||||||
|
"fsdp_version": 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
cfg = validate_config(
|
||||||
|
cfg,
|
||||||
|
)
|
||||||
|
assert cfg.fsdp_version == 2
|
||||||
|
assert cfg.fsdp_config.fsdp_version is None
|
||||||
|
|
||||||
|
def test_fsdp_sharded_state_dict_safetensors(self, fsdp_base_cfg):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_config={
|
||||||
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||||
|
},
|
||||||
|
save_safetensors=True,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# test w/o prefix too
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_config={
|
||||||
|
"state_dict_type": "SHARDED_STATE_DICT",
|
||||||
|
},
|
||||||
|
save_safetensors=True,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_fsdp_offload_w_8bit_optim(self, fsdp_base_cfg):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_config={
|
||||||
|
"offload_params": True,
|
||||||
|
},
|
||||||
|
optimizer="adamw_8bit",
|
||||||
|
fsdp_version=1,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="FSDP Offload not compatible with adamw_8bit"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_fsdp2_w_8bit_optim(self, fsdp_base_cfg):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_config={
|
||||||
|
"offload_params": True,
|
||||||
|
},
|
||||||
|
optimizer="adamw_8bit",
|
||||||
|
fsdp_version=2,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_fsdp2_w_cpu_ram_efficient_loading(self, fsdp_base_cfg):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
load_in_8bit=True,
|
||||||
|
adapter="lora",
|
||||||
|
fsdp_config={
|
||||||
|
"cpu_ram_efficient_loading": True,
|
||||||
|
},
|
||||||
|
fsdp_version=2,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_fsdp_prefixes_removed(self, fsdp_base_cfg):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_config={
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
|
"fsdp_reshard_after_forward": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.fsdp_version == 2
|
||||||
|
assert cfg.fsdp_config.fsdp_version is None
|
||||||
|
for keys in cfg.fsdp_config.keys():
|
||||||
|
assert not keys.startswith("fsdp_")
|
||||||
|
assert cfg.fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP"
|
||||||
|
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
|
||||||
|
assert cfg.fsdp_config.reshard_after_forward is True
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"rl",
|
||||||
|
[
|
||||||
|
"dpo",
|
||||||
|
"kto",
|
||||||
|
"orpo",
|
||||||
|
"ipo",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_fsdp2_dpo(self, fsdp_base_cfg, rl):
|
||||||
|
cfg = fsdp_base_cfg | DictDefault(
|
||||||
|
fsdp_version=2,
|
||||||
|
fsdp_config={
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
rl=rl,
|
||||||
|
load_in_8bit=True,
|
||||||
|
adapter="lora",
|
||||||
|
remove_unused_columns=False,
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="FSDP2 does not support load_in_8bit or load_in_4bit with ",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
Reference in New Issue
Block a user