FSDP1 -> FSDP2 (#2760)
* FSDP2 args migration implementation. This commit implements the migration to FSDP2 arguments, including: FSDP2 support with LoRA training; DPO integration with FSDP2; model loading fixes and refactoring; CPU offloading and PEFT handling; test updates and CI improvements; and bug fixes for dtype errors and various edge cases.
This commit is contained in:
@@ -5,7 +5,11 @@ Test classes for checking functionality of the cfg normalization
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.utils.config import normalize_cfg_datasets, normalize_config
|
||||
from axolotl.utils.config import (
|
||||
migrate_fsdp_config,
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -90,3 +94,104 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
|
||||
self.assertTrue(cfg.bf16)
|
||||
self.assertFalse(cfg.fp16)
|
||||
|
||||
def test_migrate_fsdp_config(self):
    """Check the basic migration of ``fsdp_config``, with and without a version.

    Two configs are run through ``migrate_fsdp_config`` (which is expected to
    modify the config it is given, since its return value is not used here):

    * one whose ``fsdp_config`` carries ``fsdp_version`` -- the version must end
      up on the top-level config and must not remain inside ``fsdp_config``
      under either ``fsdp_version`` or ``version``;
    * one without any version -- no top-level ``fsdp_version`` may appear.

    In both cases every ``fsdp_``-prefixed key must be readable under its
    un-prefixed name, the prefixed key must be gone, and keys that never had
    the prefix must keep their value.
    """
    # --- Scenario 1: fsdp_version is present inside fsdp_config -------------
    versioned = DictDefault(
        {
            "fsdp_config": {
                "fsdp_version": 2,
                "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                "fsdp_offload_params": False,
                "fsdp_cpu_ram_efficient_loading": True,
                "regular_param": "value",
            }
        }
    )

    migrate_fsdp_config(versioned)

    # The version is expected on the top-level config after migration.
    self.assertEqual(versioned.fsdp_version, 2)

    versioned_expected = (
        ("auto_wrap_policy", "TRANSFORMER_BASED_WRAP"),
        ("offload_params", False),
        ("cpu_ram_efficient_loading", True),
        ("regular_param", "value"),
    )
    for attr_name, wanted in versioned_expected:
        self.assertEqual(getattr(versioned.fsdp_config, attr_name), wanted)

    # Neither the old prefixed keys nor any form of the version key may remain.
    versioned_removed = (
        "fsdp_auto_wrap_policy",
        "fsdp_offload_params",
        "fsdp_cpu_ram_efficient_loading",
        "fsdp_version",
        "version",
    )
    for stale_key in versioned_removed:
        self.assertNotIn(stale_key, versioned.fsdp_config)

    # --- Scenario 2: no fsdp_version anywhere --------------------------------
    unversioned = DictDefault(
        {
            "fsdp_config": {
                "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
                "fsdp_offload_params": True,
                "regular_param": "value",
            }
        }
    )

    migrate_fsdp_config(unversioned)

    # Migration must not invent a version when none was supplied.
    self.assertNotIn("fsdp_version", unversioned)

    unversioned_expected = (
        ("auto_wrap_policy", "SIZE_BASED_WRAP"),
        ("offload_params", True),
        ("regular_param", "value"),
    )
    for attr_name, wanted in unversioned_expected:
        self.assertEqual(getattr(unversioned.fsdp_config, attr_name), wanted)

    for stale_key in ("fsdp_auto_wrap_policy", "fsdp_offload_params"):
        self.assertNotIn(stale_key, unversioned.fsdp_config)
|
||||
|
||||
def test_migrate_fsdp_config_no_fsdp_config(self):
    """Check that migration is harmless when the config has no ``fsdp_config``.

    The call must complete without raising, must not add ``fsdp_config`` or
    ``fsdp_version`` to the config, and must leave the unrelated key intact.
    """
    plain_cfg = DictDefault({"some_other_config": "value"})

    migrate_fsdp_config(plain_cfg)

    # Nothing FSDP-related may have been created by the migration.
    for absent_key in ("fsdp_config", "fsdp_version"):
        self.assertNotIn(absent_key, plain_cfg)

    # The unrelated setting keeps its original value.
    self.assertEqual(plain_cfg.some_other_config, "value")
|
||||
|
||||
def test_migrate_fsdp_config_empty_fsdp_config(self):
    """Check migration when ``fsdp_config`` is present but empty.

    An empty ``fsdp_config`` must stay empty, and no top-level
    ``fsdp_version`` may be introduced.
    """
    empty_cfg = DictDefault({"fsdp_config": {}})

    migrate_fsdp_config(empty_cfg)

    # No version was supplied, so none may appear on the top-level config.
    self.assertNotIn("fsdp_version", empty_cfg)
    # The nested config must still compare equal to an empty dict.
    self.assertEqual(empty_cfg.fsdp_config, {})
|
||||
|
||||
def test_migrate_fsdp_config_mixed_keys(self):
    """Check migration when prefixed and un-prefixed keys are mixed together.

    ``fsdp_``-prefixed keys must become readable under their un-prefixed
    names, keys that never had the prefix must keep their values, and
    ``fsdp_version`` must end up on the top-level config without leaving a
    ``fsdp_version`` or ``version`` entry behind in ``fsdp_config``.
    """
    mixed_cfg = DictDefault(
        {
            "fsdp_config": {
                "fsdp_version": 1,
                "fsdp_state_dict_type": "FULL_STATE_DICT",
                "mixed_precision_policy": "fp16",
                "activation_checkpointing": True,
                "fsdp_reshard_after_forward": False,
            }
        }
    )

    migrate_fsdp_config(mixed_cfg)

    # The version is expected on the top-level config after migration.
    self.assertEqual(mixed_cfg.fsdp_version, 1)

    # First the renamed (formerly prefixed) keys, then the untouched ones.
    expected_values = (
        ("state_dict_type", "FULL_STATE_DICT"),
        ("reshard_after_forward", False),
        ("mixed_precision_policy", "fp16"),
        ("activation_checkpointing", True),
    )
    for attr_name, wanted in expected_values:
        self.assertEqual(getattr(mixed_cfg.fsdp_config, attr_name), wanted)

    # The original fsdp_-prefixed keys must be gone, and the version must not
    # be duplicated inside fsdp_config under the un-prefixed name either.
    removed_keys = (
        "fsdp_version",
        "fsdp_state_dict_type",
        "fsdp_reshard_after_forward",
        "version",
    )
    for stale_key in removed_keys:
        self.assertNotIn(stale_key, mixed_cfg.fsdp_config)
|
||||
|
||||
Reference in New Issue
Block a user