Support for additional_special_tokens (#1221) [skip ci]
* Support for additional_special_tokens
* Support for additional_special_tokens. Adjust whitespace.
* Support for additional_special_tokens. Use correct quotes.
* Support for additional_special_tokens. Safe pop.
* Support for additional_special_tokens. nt.
* Support for additional_special_tokens. cfg.special_tokens may be None.
* add token if not in vocabulary when adding additional_special_tokens
* fix logic for copy/pasta
* bugfix for popping from config and tokenizer reload
* no need to add tokens manually now with previous bugfix

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
@@ -161,15 +161,20 @@ def load_tokenizer(cfg):
         if getattr(tokenizer, attr_name) is None:
             setattr(tokenizer, attr_name, "<|endoftext|>")
 
+    additional_special_tokens = None
     if cfg.special_tokens:
+        special_tokens = cfg.special_tokens.to_dict()
+        additional_special_tokens = special_tokens.pop(
+            "additional_special_tokens", None
+        )
         lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
-        for k, val in cfg.special_tokens.items():
+        for k, val in special_tokens.items():
             # check if new special token is not already in tokenizer and
             # is adapter training to make sure lora_modules_to_save is set
             # pylint: disable=too-many-boolean-expressions
             if (
                 (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
-                and (len(tokenizer.encode(val)) > 1)
+                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
                 and cfg.adapter
                 and (
                     not cfg.lora_modules_to_save
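The core of this hunk is splitting the list-valued `additional_special_tokens` entry out of the scalar special tokens before the per-token loop runs, so the loop only ever sees string values and a missing key degrades gracefully. A minimal standalone sketch of that pattern with a plain dict (the token values here are illustrative, not taken from this diff):

```python
# Sketch: separate the list-valued entry from the scalar special tokens.
special_tokens = {
    "eos_token": "</s>",
    "pad_token": "</s>",
    "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
}

# "Safe pop": returns None when the key is absent, so configs without
# additional_special_tokens keep working unchanged.
additional_special_tokens = special_tokens.pop("additional_special_tokens", None)

# Only scalar string tokens remain for the per-token checks and assignments.
for name, value in special_tokens.items():
    assert isinstance(value, str), (name, value)

print(additional_special_tokens)  # ['<|im_start|>', '<|im_end|>']
print(sorted(special_tokens))     # ['eos_token', 'pad_token']
```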
@@ -213,6 +218,21 @@ def load_tokenizer(cfg):
             ]
         )
 
+    # Additional special tokens are a List, and need to be treated differently than regular special
+    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
+    # are new tokens.
+    #
+    # Usage:
+    #
+    # ```py
+    # special_tokens:
+    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
+    # ```
+    if additional_special_tokens is not None:
+        tokenizer.add_special_tokens(
+            {"additional_special_tokens": additional_special_tokens}
+        )
+
     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
     LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
     LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
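This hunk applies the list in a single `tokenizer.add_special_tokens` call after any plain `add_tokens` handling, which is why the commit message notes the tokens no longer need to be added manually. A rough sketch of the behavior this relies on, assuming the `transformers` tokenizer API and the `huggyllama/llama-7b` checkpoint used by the test below (the call downloads from the Hugging Face Hub; the counts in the comments are my expectation, not output copied from this PR):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
print(len(tokenizer))  # 32000: the base Llama vocabulary

# Tokens not yet in the vocabulary get fresh ids appended...
added = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)
print(added, len(tokenizer))  # 2 32002

# ...while re-applying the same list adds nothing, so reloading the tokenizer
# from the same config does not keep growing the vocabulary.
added_again = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
)
print(added_again, len(tokenizer))  # 0 32002
```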
@@ -67,6 +67,21 @@ class TestTokenizers(unittest.TestCase):
         )
         load_tokenizer(cfg)
 
+    def test_add_additional_special_tokens(self):
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+                "special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
+        self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
+        self.assertEqual(len(tokenizer), 32001)
+
+        # ensure reloading the tokenizer again from cfg results in same vocab length
+        tokenizer = load_tokenizer(cfg)
+        self.assertEqual(len(tokenizer), 32001)
+
 
 if __name__ == "__main__":
     unittest.main()
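The test pins the id assigned to the new token: 32000, the first id past the 32000-entry base Llama vocabulary, and re-runs `load_tokenizer` to confirm the vocabulary does not grow on reload. A hedged sketch of checking the same mapping directly against a `transformers` tokenizer, outside the axolotl test harness:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_start|>"]})

# The new special token should resolve to a single id right after the base vocab,
# matching the [1, 32000, 1404] expectation in the test above (1 is BOS).
print(tokenizer.convert_tokens_to_ids("<|im_start|>"))  # 32000
print(tokenizer("<|im_start|>user")["input_ids"])        # [1, 32000, 1404]
```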