new validation for mpt w grad checkpoints

This commit is contained in:
Wing Lian
2023-06-11 09:26:10 -04:00
parent fe0b76854e
commit 14668fa54e
2 changed files with 19 additions and 0 deletions

View File

@@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_mpt_gradient_checkpointing(self):
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
# Check for lower-case
cfg = DictDefault(
{
"base_model": "mosaicml/mpt-7b",
"gradient_checkpointing": True,
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)