Validate falcon with fsdp

This commit is contained in:
NanoCode012
2023-06-09 00:29:04 +09:00
parent 81911d112c
commit babf0fdb71
2 changed files with 36 additions and 0 deletions

View File

@@ -54,6 +54,9 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if "falcon" in cfg.base_model.lower() and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_falcon_fsdp(self):
regex_exp = r".*FSDP is not supported for falcon models.*"
# Check for lower-case
cfg = DictDefault(
{
"base_model": "tiiuae/falcon-7b",
"fsdp": ["full_shard", "auto_wrap"],
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
# Check for upper-case
cfg = DictDefault(
{
"base_model": "Falcon-7b",
"fsdp": ["full_shard", "auto_wrap"],
}
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
cfg = DictDefault(
{
"base_model": "tiiuae/falcon-7b",
}
)
validate_config(cfg)