From 14668fa54ec8c35771d50ff7956cbb6541e81f6a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 11 Jun 2023 09:26:10 -0400 Subject: [PATCH] new validation for mpt w grad checkpoints --- src/axolotl/utils/validation.py | 5 +++++ tests/test_validation.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 04ffc4c1b..e2d0b34b1 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -57,6 +57,11 @@ def validate_config(cfg): if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp: raise ValueError("FSDP is not supported for falcon models") + if ( + cfg.base_model and "mpt" in cfg.base_model.lower() + ) and cfg.gradient_checkpointing: + raise ValueError("gradient_checkpointing is not supported for MPT models") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index 50bdf37e6..e28891060 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -198,3 +198,17 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_mpt_gradient_checkpointing(self): + regex_exp = r".*gradient_checkpointing is not supported for MPT models*" + + # Check for lower-case + cfg = DictDefault( + { + "base_model": "mosaicml/mpt-7b", + "gradient_checkpointing": True, + } + ) + + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg)