diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 5f75352f3..cb0eece7f 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available from axolotl.integrations.base import PluginManager from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( - migrate_fsdp_config, normalize_cfg_datasets, normalize_config, validate_config, @@ -227,7 +226,6 @@ def load_cfg( }, ) - migrate_fsdp_config(cfg) prepare_optim_env(cfg) prepare_opinionated_env(cfg) normalize_config(cfg) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 4e26a257d..aaa203e82 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -314,16 +314,3 @@ def prepare_plugins(cfg): plugin_manager = PluginManager.get_instance() for plugin_name in cfg["plugins"]: plugin_manager.register(plugin_name) - - -# TODO @SalmanMohammadi remove this function in 0.12 -def migrate_fsdp_config(cfg): - if cfg.get("fsdp_config"): - fsdp_config_keys = cfg.fsdp_config.keys() - if "fsdp_version" in fsdp_config_keys: - cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version") - - for key in list(fsdp_config_keys): - if key.startswith("fsdp_") and key != "fsdp_version": - cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key] - del cfg.fsdp_config[key] diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index de80d1b79..6668380bf 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1143,72 +1143,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): return data - @model_validator(mode="before") - @classmethod - def check_fsdp_version(cls, data): - fsdp_config = data.get("fsdp_config", {}) - if fsdp_config and str(data.get("fsdp_version")) != "2": - LOG.info( - "FSDP1 will be deprecated in an upcoming release of Axolotl." - "We recommend that you use FSDP version 2 for better performance and compatibility. " - "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp " - "For more details on migrating your config. " - ) - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data): - fsdp_config = data.get("fsdp_config") - if fsdp_config and data.get("fsdp_version") == 2: - if fsdp_config.get("cpu_ram_efficient_loading") and ( - data.get("load_in_8bit") or data.get("load_in_4bit") - ): - raise ValueError( - "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " - "set fsdp_version to 1, or disable cpu_ram_efficient_loading." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp2_base_model_quant_dpo(cls, data): - if data.get("fsdp_version") == 2 and data.get("rl") in [ - RLType.DPO, - RLType.KTO, - RLType.ORPO, - RLType.IPO, - ]: - if data.get("load_in_8bit") or data.get("load_in_4bit"): - raise ValueError( - "FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1." - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_version_in_fsdp_config(cls, data): - if fsdp_config := data.get("fsdp_config"): - if fsdp_config.get("fsdp_version"): - LOG.warning( - "Configuring `fsdp_version` in `fsdp_config` is deprecated. " - "Please configure `fsdp_version` as a top-level field." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_config_kwargs_prefix(cls, data): - if fsdp_config := data.get("fsdp_config"): - for key, _ in fsdp_config.items(): - if key.startswith("fsdp_"): - LOG.warning_once( - "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " - "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." - ) - return data - @model_validator(mode="before") @classmethod def default_dataloader_opts(cls, data): diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 57959c4fa..534d89a98 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1,6 +1,6 @@ """Module with validation methods for config pydantic model.""" -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,too-many-boolean-expressions import logging @@ -748,44 +748,128 @@ class OptimizationValidationMixin: @model_validator(mode="before") @classmethod - def check_fsdp_offload_w_8bit_optimizer(cls, data): - if ( - data.get("fsdp") - and "8bit" in data.get("optimizer", "") - and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_offload_params") - and str(data["fsdp_config"].get("fsdp_version")) != "2" - ): - raise ValueError( - f"FSDP Offload not compatible with {data.get('optimizer')}" + def check_fsdp_version(cls, data): + fsdp_config = data.get("fsdp_config", {}) + if fsdp_config and str(data.get("fsdp_version")) != "2": + LOG.info( + "FSDP1 will be deprecated in an upcoming release of Axolotl." + "We recommend that you use FSDP version 2 for better performance and compatibility. " + "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp " + "For more details on migrating your config. " ) - if ( - data.get("fsdp") - and "8bit" in data.get("optimizer", "") - and data.get("fsdp_config") - and str(data["fsdp_config"].get("fsdp_version")) == "2" - ): - if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: - # CUDA ops errors with bnb 8bit optimizer + FSDP2 + return data + + @model_validator(mode="after") + def check_fsdp2_base_model_quant_ram_efficient_loading(self): + fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None + fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None + load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None + load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None + if fsdp_config and fsdp_version == 2: + if fsdp_config.get("cpu_ram_efficient_loading") and ( + load_in_8bit or load_in_4bit + ): raise ValueError( - f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead" + "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " + "set fsdp_version to 1, or disable cpu_ram_efficient_loading." + ) + return self + + @model_validator(mode="before") + @classmethod + def check_fsdp2_base_model_quant_rl(cls, data): + if data.get("fsdp_version") == 2 and data.get("rl") in [ + RLType.DPO, + RLType.KTO, + RLType.ORPO, + RLType.IPO, + ]: + if data.get("load_in_8bit") or data.get("load_in_4bit"): + raise ValueError( + f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1." ) return data @model_validator(mode="before") @classmethod - def check_fsdp_sharded_state_dict_w_safetensors(cls, data): + def check_fsdp_version_in_fsdp_config(cls, data): + if data.get("fsdp_config"): + if data.get("fsdp_config", {}).get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." + ) + data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_config_kwargs_prefix(cls, data): + if fsdp_config := data.get("fsdp_config"): + should_fix = False + for key, _ in fsdp_config.items(): + if key.startswith("fsdp_"): + should_fix = True + LOG.warning_once( + "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " + "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." + ) + if should_fix: + update_fsdp_config = {} + for key, value in fsdp_config.items(): + if key.startswith("fsdp_") and key != "fsdp_version": + update_fsdp_config[key.replace("fsdp_", "")] = value + else: + update_fsdp_config[key] = value + data["fsdp_config"] = update_fsdp_config + return data + + @model_validator(mode="after") + def check_fsdp_offload_w_8bit_optimizer(self): if ( - data.get("fsdp_config") - and data.get("save_safetensors") - and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" + hasattr(self, "fsdp_config") + and self.fsdp_config + and self.optimizer + and "8bit" in self.optimizer.value + and self.fsdp_config["offload_params"] + and str(self.fsdp_version) != "2" + ): + raise ValueError( + f"FSDP Offload not compatible with {str(self.optimizer.value)}" + ) + return self + + @model_validator(mode="after") + def check_fsdp2_w_8bit_optimizer(self): + if ( + hasattr(self, "fsdp_config") + and self.fsdp_config + and self.optimizer + and "8bit" in self.optimizer.value + and str(self.fsdp_version) == "2" + ): + if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]: + # CUDA ops errors with bnb 8bit optimizer + FSDP2 + raise ValueError( + f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead" + ) + + return self + + @model_validator(mode="after") + def check_fsdp_sharded_state_dict_w_safetensors(self): + if ( + hasattr(self, "fsdp_config") + and self.fsdp_config + and hasattr(self, "save_safetensors") + and self.save_safetensors + and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT" ): raise ValueError( "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" ) - return data + return self class SystemValidationMixin: diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index 31d04fc64..658e06fcb 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -6,9 +6,9 @@ import unittest from unittest.mock import patch from axolotl.utils.config import ( - migrate_fsdp_config, normalize_cfg_datasets, normalize_config, + validate_config, ) from axolotl.utils.dict import DictDefault @@ -27,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase): "num_epochs": 1, "micro_batch_size": 1, "gradient_accumulation_steps": 1, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "learning_rate": 0.0001, } ) @@ -97,7 +104,7 @@ class NormalizeConfigTestCase(unittest.TestCase): def test_migrate_fsdp_config(self): """Test basic FSDP config migration with and without fsdp_version""" - cfg_with_version = DictDefault( + cfg_with_version = self._get_base_cfg() | DictDefault( { "fsdp_config": { "fsdp_version": 2, @@ -109,7 +116,7 @@ class NormalizeConfigTestCase(unittest.TestCase): } ) - migrate_fsdp_config(cfg_with_version) + cfg_with_version = validate_config(cfg_with_version) self.assertEqual(cfg_with_version.fsdp_version, 2) self.assertEqual( @@ -125,7 +132,7 @@ class NormalizeConfigTestCase(unittest.TestCase): self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config) self.assertNotIn("version", cfg_with_version.fsdp_config) - cfg_without_version = DictDefault( + cfg_without_version = self._get_base_cfg() | DictDefault( { "fsdp_config": { "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP", @@ -135,7 +142,7 @@ class NormalizeConfigTestCase(unittest.TestCase): } ) - migrate_fsdp_config(cfg_without_version) + cfg_without_version = validate_config(cfg_without_version) self.assertNotIn("fsdp_version", cfg_without_version) self.assertEqual( @@ -149,26 +156,25 @@ class NormalizeConfigTestCase(unittest.TestCase): def test_migrate_fsdp_config_no_fsdp_config(self): """Test that function doesn't crash when no fsdp_config is present""" - cfg = DictDefault({"some_other_config": "value"}) + cfg = self._get_base_cfg() - migrate_fsdp_config(cfg) + cfg = validate_config(cfg) self.assertNotIn("fsdp_config", cfg) self.assertNotIn("fsdp_version", cfg) - self.assertEqual(cfg.some_other_config, "value") def test_migrate_fsdp_config_empty_fsdp_config(self): """Test migration with empty fsdp_config""" - cfg = DictDefault({"fsdp_config": {}}) + cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}}) - migrate_fsdp_config(cfg) + cfg = validate_config(cfg) self.assertNotIn("fsdp_version", cfg) self.assertEqual(cfg.fsdp_config, {}) def test_migrate_fsdp_config_mixed_keys(self): """Test migration with a mix of fsdp_ and non-fsdp_ keys""" - cfg = DictDefault( + cfg = self._get_base_cfg() | DictDefault( { "fsdp_config": { "fsdp_version": 1, @@ -180,7 +186,7 @@ class NormalizeConfigTestCase(unittest.TestCase): } ) - migrate_fsdp_config(cfg) + cfg = validate_config(cfg) self.assertEqual(cfg.fsdp_version, 1) self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT") diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py new file mode 100644 index 000000000..456040bc1 --- /dev/null +++ b/tests/utils/schemas/validation/test_fsdp.py @@ -0,0 +1,155 @@ +""" +tests for pydantic fsdp validation +""" + +# pylint: disable=too-many-boolean-expressions +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="fsdp_base_cfg") +def fixture_fsdp_base_cfg(): + return DictDefault( + base_model="gpt2", + learning_rate=1e-3, + datasets=[ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + micro_batch_size=1, + gradient_accumulation_steps=1, + ) + + +class TestFSDPValidation: + """ + test class for pydantic fsdp validation + """ + + def test_fsdp_version_in_fsdp_config(self, fsdp_base_cfg): + cfg = fsdp_base_cfg | DictDefault( + fsdp_config={ + "fsdp_version": 2, + }, + ) + cfg = validate_config( + cfg, + ) + assert cfg.fsdp_version == 2 + assert cfg.fsdp_config.fsdp_version is None + + def test_fsdp_sharded_state_dict_safetensors(self, fsdp_base_cfg): + cfg = fsdp_base_cfg | DictDefault( + fsdp_config={ + "fsdp_state_dict_type": "SHARDED_STATE_DICT", + }, + save_safetensors=True, + ) + with pytest.raises( + ValueError, + match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors", + ): + validate_config(cfg) + + # test w/o prefix too + cfg = fsdp_base_cfg | DictDefault( + fsdp_config={ + "state_dict_type": "SHARDED_STATE_DICT", + }, + save_safetensors=True, + ) + with pytest.raises( + ValueError, + match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors", + ): + validate_config(cfg) + + def test_fsdp_offload_w_8bit_optim(self, fsdp_base_cfg): + cfg = fsdp_base_cfg | DictDefault( + fsdp_config={ + "offload_params": True, + }, + optimizer="adamw_8bit", + fsdp_version=1, + ) + with pytest.raises( + ValueError, match="FSDP Offload not compatible with adamw_8bit" + ): + validate_config(cfg) + + def test_fsdp2_w_8bit_optim(self, fsdp_base_cfg): + cfg = fsdp_base_cfg | DictDefault( + fsdp_config={ + "offload_params": True, + }, + optimizer="adamw_8bit", + fsdp_version=2, + ) + with pytest.raises( + ValueError, + match="FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead", + ): + validate_config(cfg) + + def test_fsdp2_w_cpu_ram_efficient_loading(self, fsdp_base_cfg): + cfg = fsdp_base_cfg | DictDefault( + load_in_8bit=True, + adapter="lora", + fsdp_config={ + "cpu_ram_efficient_loading": True, + }, + fsdp_version=2, + ) + with pytest.raises( + ValueError, + match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.", + ): + validate_config(cfg) + + def test_fsdp_prefixes_removed(self, fsdp_base_cfg): + cfg = fsdp_base_cfg | DictDefault( + fsdp_config={ + "fsdp_version": 2, + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "fsdp_reshard_after_forward": True, + } + ) + cfg = validate_config(cfg) + assert cfg.fsdp_version == 2 + assert cfg.fsdp_config.fsdp_version is None + for keys in cfg.fsdp_config.keys(): + assert not keys.startswith("fsdp_") + assert cfg.fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP" + assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer" + assert cfg.fsdp_config.reshard_after_forward is True + + @pytest.mark.parametrize( + "rl", + [ + "dpo", + "kto", + "orpo", + "ipo", + ], + ) + def test_fsdp2_dpo(self, fsdp_base_cfg, rl): + cfg = fsdp_base_cfg | DictDefault( + fsdp_version=2, + fsdp_config={ + "reshard_after_forward": True, + }, + rl=rl, + load_in_8bit=True, + adapter="lora", + remove_unused_columns=False, + ) + with pytest.raises( + ValueError, + match="FSDP2 does not support load_in_8bit or load_in_4bit with ", + ): + validate_config(cfg)