refactor dupes from merge/rebase (#2919) [skip ci]

This commit is contained in:
Wing Lian
2025-07-14 10:05:26 -04:00
committed by GitHub
parent af92151a7b
commit e581c15d40
3 changed files with 43 additions and 46 deletions

View File

@@ -22,6 +22,8 @@ from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken from tokenizers import AddedToken
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import ( from tests.hf_offline_utils import (
enable_hf_offline, enable_hf_offline,
hf_offline_context, hf_offline_context,
@@ -539,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"] return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture(name="min_base_cfg")
def fixture_min_base_cfg():
return DictDefault(
base_model="HuggingFaceTB/SmolLM2-135M",
learning_rate=1e-3,
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
micro_batch_size=1,
gradient_accumulation_steps=1,
)
# # pylint: disable=redefined-outer-name,unused-argument # # pylint: disable=redefined-outer-name,unused-argument
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1", os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",

View File

@@ -7,21 +7,16 @@ from axolotl.utils.dict import DictDefault
@pytest.fixture(name="train_base_cfg") @pytest.fixture(name="train_base_cfg")
def fixture_train_base_cfg(): def fixture_train_base_cfg(min_base_cfg):
return DictDefault( return (
base_model="gpt2", DictDefault(
learning_rate=1e-3, micro_batch_size=2,
datasets=[ gradient_accumulation_steps=4,
{ sequence_len=2048,
"path": "mhenrichsen/alpaca_2k_test", sample_packing=True,
"type": "alpaca", num_epochs=1,
}, )
], | min_base_cfg
micro_batch_size=2,
gradient_accumulation_steps=4,
sequence_len=2048,
sample_packing=True,
num_epochs=1,
) )

View File

@@ -9,29 +9,13 @@ from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@pytest.fixture(name="fsdp_base_cfg")
def fixture_fsdp_base_cfg():
return DictDefault(
base_model="gpt2",
learning_rate=1e-3,
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
micro_batch_size=1,
gradient_accumulation_steps=1,
)
class TestFSDPValidation: class TestFSDPValidation:
""" """
test class for pydantic fsdp validation test class for pydantic fsdp validation
""" """
def test_fsdp_version_in_fsdp_config(self, fsdp_base_cfg): def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={
"fsdp_version": 2, "fsdp_version": 2,
}, },
@@ -42,8 +26,8 @@ class TestFSDPValidation:
assert cfg.fsdp_version == 2 assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None assert cfg.fsdp_config.fsdp_version is None
def test_fsdp_sharded_state_dict_safetensors(self, fsdp_base_cfg): def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={
"fsdp_state_dict_type": "SHARDED_STATE_DICT", "fsdp_state_dict_type": "SHARDED_STATE_DICT",
}, },
@@ -56,7 +40,7 @@ class TestFSDPValidation:
validate_config(cfg) validate_config(cfg)
# test w/o prefix too # test w/o prefix too
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={
"state_dict_type": "SHARDED_STATE_DICT", "state_dict_type": "SHARDED_STATE_DICT",
}, },
@@ -68,8 +52,8 @@ class TestFSDPValidation:
): ):
validate_config(cfg) validate_config(cfg)
def test_fsdp_offload_w_8bit_optim(self, fsdp_base_cfg): def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={
"offload_params": True, "offload_params": True,
}, },
@@ -81,8 +65,8 @@ class TestFSDPValidation:
): ):
validate_config(cfg) validate_config(cfg)
def test_fsdp2_w_8bit_optim(self, fsdp_base_cfg): def test_fsdp2_w_8bit_optim(self, min_base_cfg):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={
"offload_params": True, "offload_params": True,
}, },
@@ -95,8 +79,8 @@ class TestFSDPValidation:
): ):
validate_config(cfg) validate_config(cfg)
def test_fsdp2_w_cpu_ram_efficient_loading(self, fsdp_base_cfg): def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
load_in_8bit=True, load_in_8bit=True,
adapter="lora", adapter="lora",
fsdp_config={ fsdp_config={
@@ -110,8 +94,8 @@ class TestFSDPValidation:
): ):
validate_config(cfg) validate_config(cfg)
def test_fsdp_prefixes_removed(self, fsdp_base_cfg): def test_fsdp_prefixes_removed(self, min_base_cfg):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_config={ fsdp_config={
"fsdp_version": 2, "fsdp_version": 2,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
@@ -137,8 +121,8 @@ class TestFSDPValidation:
"ipo", "ipo",
], ],
) )
def test_fsdp2_dpo(self, fsdp_base_cfg, rl): def test_fsdp2_dpo(self, min_base_cfg, rl):
cfg = fsdp_base_cfg | DictDefault( cfg = min_base_cfg | DictDefault(
fsdp_version=2, fsdp_version=2,
fsdp_config={ fsdp_config={
"reshard_after_forward": True, "reshard_after_forward": True,