New validation: reject gradient checkpointing for MPT models

This commit is contained in:
Wing Lian
2023-06-11 09:26:10 -04:00
parent fe0b76854e
commit 14668fa54e
2 changed files with 19 additions and 0 deletions

View File

@@ -57,6 +57,11 @@ def validate_config(cfg):
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
if (
cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_mpt_gradient_checkpointing(self):
    """validate_config must raise for MPT base models with gradient_checkpointing enabled."""
    # Trailing ".*" (was "s*", which only quantified the literal "s" and
    # would also have matched "...MPT model") — anchor-free match on the
    # full error message.
    regex_exp = r".*gradient_checkpointing is not supported for MPT models.*"
    # Check for lower-case
    cfg = DictDefault(
        {
            "base_model": "mosaicml/mpt-7b",
            "gradient_checkpointing": True,
        }
    )
    with pytest.raises(ValueError, match=regex_exp):
        validate_config(cfg)
    # Check for mixed/upper-case spelling — validate_config lowercases
    # base_model before matching, so this must raise too.
    cfg = DictDefault(
        {
            "base_model": "MosaicML/MPT-7b",
            "gradient_checkpointing": True,
        }
    )
    with pytest.raises(ValueError, match=regex_exp):
        validate_config(cfg)