diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 1b4ce9246..d9e56b95a 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -448,6 +448,20 @@ def validate_config(cfg):
     if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
         raise ValueError("neftune_noise_alpha must be > 0.0")
 
+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(
+                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
+            )
+        )
+    ):
+        raise ValueError(
+            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 022229af8..8cb9e8426 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -136,6 +136,23 @@ def load_tokenizer(cfg):
 
     if cfg.special_tokens:
         for k, val in cfg.special_tokens.items():
+            # check if new special token is not already in tokenizer and
+            # is adapter training to make sure lora_modules_to_save is set
+            if (
+                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
+                and cfg.adapter
+                and (
+                    not cfg.lora_modules_to_save
+                    or not all(
+                        x in cfg.lora_modules_to_save
+                        for x in ["embed_tokens", "lm_head"]
+                    )
+                )
+            ):
+                raise ValueError(
+                    "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
+                )
+
             tokenizer.add_special_tokens(
                 {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
             )
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index 5c8339194..bfe4f06af 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -3,6 +3,8 @@ Test cases for the tokenizer loading
 """
 import unittest
 
+import pytest
+
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_tokenizer
 
@@ -31,6 +33,40 @@ class TestTokenizers(unittest.TestCase):
         tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
 
+    def test_special_tokens_modules_to_save(self):
+        # setting special_tokens to new token
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        with pytest.raises(
+            ValueError,
+            match=r".*Please set lora_modules_to_save*",
+        ):
+            load_tokenizer(cfg)
+
+        # setting special_tokens but not changing from default
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "adapter": "lora",
+                "special_tokens": {"bos_token": "<s>"},
+            }
+        )
+        load_tokenizer(cfg)
+
+        # non-adapter setting special_tokens
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"bos_token": "[INST]"},
+            }
+        )
+        load_tokenizer(cfg)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_validation.py b/tests/test_validation.py
index fabc23da3..12997b023 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -682,6 +682,43 @@ class ValidationTest(unittest.TestCase):
 
         validate_config(cfg)
 
+    def test_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        validate_config(cfg)
+
 class ValidationWandbTest(ValidationTest):
     """