refactor dupes from merge/rebase (#2919) [skip ci]
This commit is contained in:
@@ -22,6 +22,8 @@ from huggingface_hub.errors import LocalEntryNotFoundError
|
|||||||
from tokenizers import AddedToken
|
from tokenizers import AddedToken
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.hf_offline_utils import (
|
from tests.hf_offline_utils import (
|
||||||
enable_hf_offline,
|
enable_hf_offline,
|
||||||
hf_offline_context,
|
hf_offline_context,
|
||||||
@@ -539,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
|||||||
return datasets.load_from_disk(ds_path)["train"]
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="min_base_cfg")
|
||||||
|
def fixture_min_base_cfg():
|
||||||
|
return DictDefault(
|
||||||
|
base_model="HuggingFaceTB/SmolLM2-135M",
|
||||||
|
learning_rate=1e-3,
|
||||||
|
datasets=[
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
micro_batch_size=1,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# # pylint: disable=redefined-outer-name,unused-argument
|
# # pylint: disable=redefined-outer-name,unused-argument
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||||
|
|||||||
@@ -7,21 +7,16 @@ from axolotl.utils.dict import DictDefault
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="train_base_cfg")
|
@pytest.fixture(name="train_base_cfg")
|
||||||
def fixture_train_base_cfg():
|
def fixture_train_base_cfg(min_base_cfg):
|
||||||
return DictDefault(
|
return (
|
||||||
base_model="gpt2",
|
DictDefault(
|
||||||
learning_rate=1e-3,
|
micro_batch_size=2,
|
||||||
datasets=[
|
gradient_accumulation_steps=4,
|
||||||
{
|
sequence_len=2048,
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
sample_packing=True,
|
||||||
"type": "alpaca",
|
num_epochs=1,
|
||||||
},
|
)
|
||||||
],
|
| min_base_cfg
|
||||||
micro_batch_size=2,
|
|
||||||
gradient_accumulation_steps=4,
|
|
||||||
sequence_len=2048,
|
|
||||||
sample_packing=True,
|
|
||||||
num_epochs=1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,29 +9,13 @@ from axolotl.utils.config import validate_config
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="fsdp_base_cfg")
|
|
||||||
def fixture_fsdp_base_cfg():
|
|
||||||
return DictDefault(
|
|
||||||
base_model="gpt2",
|
|
||||||
learning_rate=1e-3,
|
|
||||||
datasets=[
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
micro_batch_size=1,
|
|
||||||
gradient_accumulation_steps=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFSDPValidation:
|
class TestFSDPValidation:
|
||||||
"""
|
"""
|
||||||
test class for pydantic fsdp validation
|
test class for pydantic fsdp validation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_fsdp_version_in_fsdp_config(self, fsdp_base_cfg):
|
def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"fsdp_version": 2,
|
"fsdp_version": 2,
|
||||||
},
|
},
|
||||||
@@ -42,8 +26,8 @@ class TestFSDPValidation:
|
|||||||
assert cfg.fsdp_version == 2
|
assert cfg.fsdp_version == 2
|
||||||
assert cfg.fsdp_config.fsdp_version is None
|
assert cfg.fsdp_config.fsdp_version is None
|
||||||
|
|
||||||
def test_fsdp_sharded_state_dict_safetensors(self, fsdp_base_cfg):
|
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||||
},
|
},
|
||||||
@@ -56,7 +40,7 @@ class TestFSDPValidation:
|
|||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
# test w/o prefix too
|
# test w/o prefix too
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"state_dict_type": "SHARDED_STATE_DICT",
|
"state_dict_type": "SHARDED_STATE_DICT",
|
||||||
},
|
},
|
||||||
@@ -68,8 +52,8 @@ class TestFSDPValidation:
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_fsdp_offload_w_8bit_optim(self, fsdp_base_cfg):
|
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"offload_params": True,
|
"offload_params": True,
|
||||||
},
|
},
|
||||||
@@ -81,8 +65,8 @@ class TestFSDPValidation:
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_fsdp2_w_8bit_optim(self, fsdp_base_cfg):
|
def test_fsdp2_w_8bit_optim(self, min_base_cfg):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"offload_params": True,
|
"offload_params": True,
|
||||||
},
|
},
|
||||||
@@ -95,8 +79,8 @@ class TestFSDPValidation:
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_fsdp2_w_cpu_ram_efficient_loading(self, fsdp_base_cfg):
|
def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
load_in_8bit=True,
|
load_in_8bit=True,
|
||||||
adapter="lora",
|
adapter="lora",
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
@@ -110,8 +94,8 @@ class TestFSDPValidation:
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_fsdp_prefixes_removed(self, fsdp_base_cfg):
|
def test_fsdp_prefixes_removed(self, min_base_cfg):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"fsdp_version": 2,
|
"fsdp_version": 2,
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
@@ -137,8 +121,8 @@ class TestFSDPValidation:
|
|||||||
"ipo",
|
"ipo",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_fsdp2_dpo(self, fsdp_base_cfg, rl):
|
def test_fsdp2_dpo(self, min_base_cfg, rl):
|
||||||
cfg = fsdp_base_cfg | DictDefault(
|
cfg = min_base_cfg | DictDefault(
|
||||||
fsdp_version=2,
|
fsdp_version=2,
|
||||||
fsdp_config={
|
fsdp_config={
|
||||||
"reshard_after_forward": True,
|
"reshard_after_forward": True,
|
||||||
|
|||||||
Reference in New Issue
Block a user