fix None-type not iterable error when deepspeed is left blank w/ use_… (#2087)
* fix None-type not iterable error when deepspeed is left blank w/ use_reentrant: false and qlora
* added unit test [skip e2e]
* corrected test case [skip e2e]
* assert warning message [skip e2e]
* assert warning message [skip e2e]
* corrected test cases [skip e2e]
* lint
This commit is contained in:
@@ -1314,6 +1314,7 @@ class AxolotlInputConfig(
|
||||
and data.get("gradient_checkpointing_kwargs", {})
|
||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||
is False
|
||||
and data.get("deepspeed", "") is not None
|
||||
and "zero3" in data.get("deepspeed", "")
|
||||
):
|
||||
# may result in:
|
||||
|
||||
@@ -68,6 +68,53 @@ class TestValidation(BaseValidation):
|
||||
assert cfg.train_on_inputs is False
|
||||
assert cfg.weight_decay is None
|
||||
|
||||
def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):
    """Check that qlora + zero3 + ``use_reentrant: False`` logs a warning.

    Builds a config that points ``deepspeed`` at a zero3 JSON file, enables
    gradient checkpointing with ``use_reentrant`` set to ``False``, and
    selects a 4-bit qlora adapter, then runs ``validate_config`` on it.

    Args:
        minimal_cfg: fixture merged on the right-hand side of ``|``, so its
            keys take precedence over the overrides declared here.
    """
    test_cfg = DictDefault(
        {
            "deepspeed": "deepspeed_configs/zero3_bf16.json",
            "gradient_checkpointing": True,
            "gradient_checkpointing_kwargs": {"use_reentrant": False},
            "load_in_4bit": True,
            "adapter": "qlora",
        }
        | minimal_cfg
    )

    expected = "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"

    with self._caplog.at_level(logging.WARNING):
        validate_config(test_cfg)

    # Search every captured record instead of only the first one: the test
    # should not depend on this warning being emitted before any other, and
    # an empty capture should fail the assertion rather than raise IndexError.
    messages = [record.message for record in self._caplog.records]
    assert any(
        expected in message for message in messages
    ), f"expected warning was not logged; captured messages: {messages}"
|
||||
|
||||
def test_deepspeed_empty(self, minimal_cfg):
    """Smoke test: validation runs when ``deepspeed`` is an empty string.

    The config otherwise enables gradient checkpointing with
    ``use_reentrant: False`` and a 4-bit qlora adapter. The test makes no
    assertions; it passes as long as ``validate_config`` does not raise.
    """
    overrides = {
        "deepspeed": "",
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "load_in_4bit": True,
        "adapter": "qlora",
    }
    # minimal_cfg stays on the right of ``|`` so its keys win on conflict.
    cfg_under_test = DictDefault(overrides | minimal_cfg)

    validate_config(cfg_under_test)
|
||||
|
||||
def test_deepspeed_not_set(self, minimal_cfg):
    """Smoke test: validation runs when ``deepspeed`` is explicitly ``None``.

    The config otherwise enables gradient checkpointing with
    ``use_reentrant: False`` and a 4-bit qlora adapter. The test makes no
    assertions; it passes as long as ``validate_config`` does not raise.
    """
    overrides = {
        "deepspeed": None,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "load_in_4bit": True,
        "adapter": "qlora",
    }
    # minimal_cfg stays on the right of ``|`` so its keys win on conflict.
    cfg_under_test = DictDefault(overrides | minimal_cfg)

    validate_config(cfg_under_test)
|
||||
|
||||
def test_datasets_min_length(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user