Merge pull request #164 from NanoCode012/fix/falcon-fsdp-validate
Fix: Validate falcon with fsdp
This commit is contained in:
@@ -54,6 +54,9 @@ def validate_config(cfg):
|
|||||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||||
|
raise ValueError("FSDP is not supported for falcon models")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_falcon_fsdp(self):
|
||||||
|
regex_exp = r".*FSDP is not supported for falcon models.*"
|
||||||
|
|
||||||
|
# Check for lower-case
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "tiiuae/falcon-7b",
|
||||||
|
"fsdp": ["full_shard", "auto_wrap"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=regex_exp):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# Check for upper-case
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Falcon-7b",
|
||||||
|
"fsdp": ["full_shard", "auto_wrap"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=regex_exp):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "tiiuae/falcon-7b",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user