FSDP2 fix validation and add tests (#2910)
* fix validation and add tests * remove debugging and add more tests * remove migrate_fsdp
This commit is contained in:
@@ -6,9 +6,9 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.utils.config import (
|
||||
migrate_fsdp_config,
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
validate_config,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -27,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"learning_rate": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -97,7 +104,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
|
||||
def test_migrate_fsdp_config(self):
|
||||
"""Test basic FSDP config migration with and without fsdp_version"""
|
||||
cfg_with_version = DictDefault(
|
||||
cfg_with_version = self._get_base_cfg() | DictDefault(
|
||||
{
|
||||
"fsdp_config": {
|
||||
"fsdp_version": 2,
|
||||
@@ -109,7 +116,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg_with_version)
|
||||
cfg_with_version = validate_config(cfg_with_version)
|
||||
|
||||
self.assertEqual(cfg_with_version.fsdp_version, 2)
|
||||
self.assertEqual(
|
||||
@@ -125,7 +132,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
|
||||
self.assertNotIn("version", cfg_with_version.fsdp_config)
|
||||
|
||||
cfg_without_version = DictDefault(
|
||||
cfg_without_version = self._get_base_cfg() | DictDefault(
|
||||
{
|
||||
"fsdp_config": {
|
||||
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
|
||||
@@ -135,7 +142,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg_without_version)
|
||||
cfg_without_version = validate_config(cfg_without_version)
|
||||
|
||||
self.assertNotIn("fsdp_version", cfg_without_version)
|
||||
self.assertEqual(
|
||||
@@ -149,26 +156,25 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
|
||||
def test_migrate_fsdp_config_no_fsdp_config(self):
|
||||
"""Test that function doesn't crash when no fsdp_config is present"""
|
||||
cfg = DictDefault({"some_other_config": "value"})
|
||||
cfg = self._get_base_cfg()
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
self.assertNotIn("fsdp_config", cfg)
|
||||
self.assertNotIn("fsdp_version", cfg)
|
||||
self.assertEqual(cfg.some_other_config, "value")
|
||||
|
||||
def test_migrate_fsdp_config_empty_fsdp_config(self):
|
||||
"""Test migration with empty fsdp_config"""
|
||||
cfg = DictDefault({"fsdp_config": {}})
|
||||
cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}})
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
self.assertNotIn("fsdp_version", cfg)
|
||||
self.assertEqual(cfg.fsdp_config, {})
|
||||
|
||||
def test_migrate_fsdp_config_mixed_keys(self):
|
||||
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
|
||||
cfg = DictDefault(
|
||||
cfg = self._get_base_cfg() | DictDefault(
|
||||
{
|
||||
"fsdp_config": {
|
||||
"fsdp_version": 1,
|
||||
@@ -180,7 +186,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
self.assertEqual(cfg.fsdp_version, 1)
|
||||
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
|
||||
|
||||
Reference in New Issue
Block a user