From babf0fdb710de86049ade89d6874232445dfc07e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 9 Jun 2023 00:29:04 +0900 Subject: [PATCH 1/2] Validate falcon with fsdp --- src/axolotl/utils/validation.py | 3 +++ tests/test_validation.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 38e0b9819..367178719 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -54,6 +54,9 @@ def validate_config(cfg): "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" ) + if "falcon" in cfg.base_model.lower() and cfg.fsdp: + raise ValueError("FSDP is not supported for falcon models") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index ce744f762..50bdf37e6 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_falcon_fsdp(self): + regex_exp = r".*FSDP is not supported for falcon models.*" + + # Check for lower-case + cfg = DictDefault( + { + "base_model": "tiiuae/falcon-7b", + "fsdp": ["full_shard", "auto_wrap"], + } + ) + + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg) + + # Check for upper-case + cfg = DictDefault( + { + "base_model": "Falcon-7b", + "fsdp": ["full_shard", "auto_wrap"], + } + ) + + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg) + + cfg = DictDefault( + { + "base_model": "tiiuae/falcon-7b", + } + ) + + validate_config(cfg) From bfd27ba55efb181cca7d9a86308bfd53d6c54272 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 9 Jun 2023 00:35:03 +0900 Subject: [PATCH 2/2] Fix failing test --- src/axolotl/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 367178719..04ffc4c1b 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -54,7 +54,7 @@ def validate_config(cfg): "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" ) - if "falcon" in cfg.base_model.lower() and cfg.fsdp: + if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp: raise ValueError("FSDP is not supported for falcon models") # TODO