fix None-type not iterable error when deepspeed is left blank w/ use_… (#2087)
* fix None-type not iterable error when deepspeed is left blank w/ use_reentrant: false and qlora * added unit test[skip e2e] * corrected test case[skip e2e] * assert warning message [skip e2e] * assert warning message [skip e2e] * corrected test cases [skip e2e] * lint
This commit is contained in:
@@ -1314,6 +1314,7 @@ class AxolotlInputConfig(
|
|||||||
and data.get("gradient_checkpointing_kwargs", {})
|
and data.get("gradient_checkpointing_kwargs", {})
|
||||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||||
is False
|
is False
|
||||||
|
and data.get("deepspeed", "") is not None
|
||||||
and "zero3" in data.get("deepspeed", "")
|
and "zero3" in data.get("deepspeed", "")
|
||||||
):
|
):
|
||||||
# may result in:
|
# may result in:
|
||||||
|
|||||||
@@ -68,6 +68,53 @@ class TestValidation(BaseValidation):
|
|||||||
assert cfg.train_on_inputs is False
|
assert cfg.train_on_inputs is False
|
||||||
assert cfg.weight_decay is None
|
assert cfg.weight_decay is None
|
||||||
|
|
||||||
|
def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):
|
||||||
|
test_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||||
|
"gradient_checkpointing": True,
|
||||||
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
}
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(test_cfg)
|
||||||
|
assert (
|
||||||
|
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||||
|
in self._caplog.records[0].message
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_deepspeed_empty(self, minimal_cfg):
|
||||||
|
test_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"deepspeed": "",
|
||||||
|
"gradient_checkpointing": True,
|
||||||
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
}
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = validate_config(test_cfg)
|
||||||
|
|
||||||
|
def test_deepspeed_not_set(self, minimal_cfg):
|
||||||
|
test_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"deepspeed": None,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
}
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = validate_config(test_cfg)
|
||||||
|
|
||||||
def test_datasets_min_length(self):
|
def test_datasets_min_length(self):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user