diff --git a/tests/conftest.py b/tests/conftest.py index 24615fa22..9e1af318d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,8 @@ from huggingface_hub.errors import LocalEntryNotFoundError from tokenizers import AddedToken from transformers import AutoTokenizer +from axolotl.utils.dict import DictDefault + from tests.hf_offline_utils import ( enable_hf_offline, hf_offline_context, @@ -539,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff( return datasets.load_from_disk(ds_path)["train"] +@pytest.fixture(name="min_base_cfg") +def fixture_min_base_cfg(): + return DictDefault( + base_model="HuggingFaceTB/SmolLM2-135M", + learning_rate=1e-3, + datasets=[ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + micro_batch_size=1, + gradient_accumulation_steps=1, + ) + + # # pylint: disable=redefined-outer-name,unused-argument @pytest.mark.skipif( os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1", diff --git a/tests/test_train.py b/tests/test_train.py index 291e9136b..2c29b58ee 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -7,21 +7,16 @@ from axolotl.utils.dict import DictDefault @pytest.fixture(name="train_base_cfg") -def fixture_train_base_cfg(): - return DictDefault( - base_model="gpt2", - learning_rate=1e-3, - datasets=[ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - micro_batch_size=2, - gradient_accumulation_steps=4, - sequence_len=2048, - sample_packing=True, - num_epochs=1, +def fixture_train_base_cfg(min_base_cfg): + return ( + DictDefault( + micro_batch_size=2, + gradient_accumulation_steps=4, + sequence_len=2048, + sample_packing=True, + num_epochs=1, + ) + | min_base_cfg ) diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py index 456040bc1..67f4a5cf9 100644 --- a/tests/utils/schemas/validation/test_fsdp.py +++ b/tests/utils/schemas/validation/test_fsdp.py @@ -9,29 +9,13 @@ from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault -@pytest.fixture(name="fsdp_base_cfg") -def fixture_fsdp_base_cfg(): - return DictDefault( - base_model="gpt2", - learning_rate=1e-3, - datasets=[ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - micro_batch_size=1, - gradient_accumulation_steps=1, - ) - - class TestFSDPValidation: """ test class for pydantic fsdp validation """ - def test_fsdp_version_in_fsdp_config(self, fsdp_base_cfg): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp_version_in_fsdp_config(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( fsdp_config={ "fsdp_version": 2, }, @@ -42,8 +26,8 @@ class TestFSDPValidation: assert cfg.fsdp_version == 2 assert cfg.fsdp_config.fsdp_version is None - def test_fsdp_sharded_state_dict_safetensors(self, fsdp_base_cfg): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( fsdp_config={ "fsdp_state_dict_type": "SHARDED_STATE_DICT", }, @@ -56,7 +40,7 @@ class TestFSDPValidation: validate_config(cfg) # test w/o prefix too - cfg = fsdp_base_cfg | DictDefault( + cfg = min_base_cfg | DictDefault( fsdp_config={ "state_dict_type": "SHARDED_STATE_DICT", }, @@ -68,8 +52,8 @@ class TestFSDPValidation: ): validate_config(cfg) - def test_fsdp_offload_w_8bit_optim(self, fsdp_base_cfg): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp_offload_w_8bit_optim(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( fsdp_config={ "offload_params": True, }, @@ -81,8 +65,8 @@ class TestFSDPValidation: ): validate_config(cfg) - def test_fsdp2_w_8bit_optim(self, fsdp_base_cfg): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp2_w_8bit_optim(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( fsdp_config={ "offload_params": True, }, @@ -95,8 +79,8 @@ class TestFSDPValidation: ): validate_config(cfg) - def test_fsdp2_w_cpu_ram_efficient_loading(self, fsdp_base_cfg): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( load_in_8bit=True, adapter="lora", fsdp_config={ @@ -110,8 +94,8 @@ class TestFSDPValidation: ): validate_config(cfg) - def test_fsdp_prefixes_removed(self, fsdp_base_cfg): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp_prefixes_removed(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( fsdp_config={ "fsdp_version": 2, "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", @@ -137,8 +121,8 @@ class TestFSDPValidation: "ipo", ], ) - def test_fsdp2_dpo(self, fsdp_base_cfg, rl): - cfg = fsdp_base_cfg | DictDefault( + def test_fsdp2_dpo(self, min_base_cfg, rl): + cfg = min_base_cfg | DictDefault( fsdp_version=2, fsdp_config={ "reshard_after_forward": True,