refactor dupes from merge/rebase (#2919) [skip ci]
This commit is contained in:
@@ -9,29 +9,13 @@ from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="fsdp_base_cfg")
|
||||
def fixture_fsdp_base_cfg():
|
||||
return DictDefault(
|
||||
base_model="gpt2",
|
||||
learning_rate=1e-3,
|
||||
datasets=[
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
micro_batch_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
|
||||
|
||||
class TestFSDPValidation:
|
||||
"""
|
||||
test class for pydantic fsdp validation
|
||||
"""
|
||||
|
||||
def test_fsdp_version_in_fsdp_config(self, fsdp_base_cfg):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"fsdp_version": 2,
|
||||
},
|
||||
@@ -42,8 +26,8 @@ class TestFSDPValidation:
|
||||
assert cfg.fsdp_version == 2
|
||||
assert cfg.fsdp_config.fsdp_version is None
|
||||
|
||||
def test_fsdp_sharded_state_dict_safetensors(self, fsdp_base_cfg):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
},
|
||||
@@ -56,7 +40,7 @@ class TestFSDPValidation:
|
||||
validate_config(cfg)
|
||||
|
||||
# test w/o prefix too
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"state_dict_type": "SHARDED_STATE_DICT",
|
||||
},
|
||||
@@ -68,8 +52,8 @@ class TestFSDPValidation:
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp_offload_w_8bit_optim(self, fsdp_base_cfg):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"offload_params": True,
|
||||
},
|
||||
@@ -81,8 +65,8 @@ class TestFSDPValidation:
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp2_w_8bit_optim(self, fsdp_base_cfg):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp2_w_8bit_optim(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"offload_params": True,
|
||||
},
|
||||
@@ -95,8 +79,8 @@ class TestFSDPValidation:
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp2_w_cpu_ram_efficient_loading(self, fsdp_base_cfg):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
load_in_8bit=True,
|
||||
adapter="lora",
|
||||
fsdp_config={
|
||||
@@ -110,8 +94,8 @@ class TestFSDPValidation:
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_fsdp_prefixes_removed(self, fsdp_base_cfg):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp_prefixes_removed(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_config={
|
||||
"fsdp_version": 2,
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
@@ -137,8 +121,8 @@ class TestFSDPValidation:
|
||||
"ipo",
|
||||
],
|
||||
)
|
||||
def test_fsdp2_dpo(self, fsdp_base_cfg, rl):
|
||||
cfg = fsdp_base_cfg | DictDefault(
|
||||
def test_fsdp2_dpo(self, min_base_cfg, rl):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
fsdp_version=2,
|
||||
fsdp_config={
|
||||
"reshard_after_forward": True,
|
||||
|
||||
Reference in New Issue
Block a user