diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 38e0b9819..367178719 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -54,6 +54,11 @@ def validate_config(cfg):
             "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
         )
 
+    # Test cfg.fsdp and cfg.base_model before calling .lower(), so a config
+    # that omits base_model never dereferences a missing value here.
+    if cfg.fsdp and cfg.base_model and "falcon" in cfg.base_model.lower():
+        raise ValueError("FSDP is not supported for falcon models")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/tests/test_validation.py b/tests/test_validation.py
index ce744f762..50bdf37e6 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -165,3 +165,38 @@ class ValidationTest(unittest.TestCase):
         )
 
         validate_config(cfg)
+
+    def test_falcon_fsdp(self):
+        """Falcon base models are rejected with FSDP and accepted without it."""
+        regex_exp = r".*FSDP is not supported for falcon models.*"
+
+        # Check for lower-case
+        cfg = DictDefault(
+            {
+                "base_model": "tiiuae/falcon-7b",
+                "fsdp": ["full_shard", "auto_wrap"],
+            }
+        )
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        # Check for upper-case
+        cfg = DictDefault(
+            {
+                "base_model": "Falcon-7b",
+                "fsdp": ["full_shard", "auto_wrap"],
+            }
+        )
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        # Falcon without FSDP configured must pass validation
+        cfg = DictDefault(
+            {
+                "base_model": "tiiuae/falcon-7b",
+            }
+        )
+
+        validate_config(cfg)