Merge branch 'main' into flash-optimum

This commit is contained in:
Wing Lian
2023-06-12 13:12:15 -04:00
committed by GitHub
36 changed files with 461 additions and 1009 deletions

View File

@@ -199,6 +199,20 @@ class ValidationTest(unittest.TestCase):
validate_config(cfg)
def test_mpt_gradient_checkpointing(self):
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
# Check for lower-case
cfg = DictDefault(
{
"base_model": "mosaicml/mpt-7b",
"gradient_checkpointing": True,
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
def test_flash_optimum(self):
cfg = DictDefault(
{