diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c12518b2d..bee6af373 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -161,15 +161,20 @@ def load_tokenizer(cfg): if getattr(tokenizer, attr_name) is None: setattr(tokenizer, attr_name, "<|endoftext|>") + additional_special_tokens = None if cfg.special_tokens: + special_tokens = cfg.special_tokens.to_dict() + additional_special_tokens = special_tokens.pop( + "additional_special_tokens", None + ) lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - for k, val in cfg.special_tokens.items(): + for k, val in special_tokens.items(): # check if new special token is not already in tokenizer and # is adapter training to make sure lora_modules_to_save is set # pylint: disable=too-many-boolean-expressions if ( (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) - and (len(tokenizer.encode(val)) > 1) + and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) and cfg.adapter and ( not cfg.lora_modules_to_save @@ -213,6 +218,21 @@ def load_tokenizer(cfg): ] ) + # Additional special tokens are a List, and need to be treated differently than regular special + # tokens. We add them after we have called `add_tokens` in case these additional special tokens + # are new tokens. 
+ # + # Usage: + # + # ```yaml + # special_tokens: + # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] + # ``` + if additional_special_tokens is not None: + tokenizer.add_special_tokens( + {"additional_special_tokens": additional_special_tokens} + ) + LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index bfe4f06af..69c441f8c 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -67,6 +67,21 @@ class TestTokenizers(unittest.TestCase): ) load_tokenizer(cfg) + def test_add_additional_special_tokens(self): + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "special_tokens": {"additional_special_tokens": ["<|im_start|>"]}, + } + ) + tokenizer = load_tokenizer(cfg) + self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404]) + self.assertEqual(len(tokenizer), 32001) + + # ensure reloading the tokenizer again from cfg results in same vocab length + tokenizer = load_tokenizer(cfg) + self.assertEqual(len(tokenizer), 32001) + if __name__ == "__main__": unittest.main()