From 2a428e8014d1487a5a54c2b4bbf8b031fdc999d9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 21 Jul 2023 09:24:11 -0400
Subject: [PATCH] better handling since xgen tokenizer breaks with convert_tokens_to_ids

---
 src/axolotl/prompt_tokenizers.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index fb6f39b0d..e223b6d76 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -48,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
 
     @functools.lru_cache(maxsize=128)
     def _get_user_token(self):
-        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
-        if isinstance(id_or_ids, (int,)):
-            return id_or_ids
+        try:
+            id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
+            if isinstance(id_or_ids, (int,)):
+                return id_or_ids
+        except KeyError:
+            pass
         return False
 
     @functools.lru_cache(maxsize=128)
     def _get_assistant_token(self):
-        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
-        if isinstance(id_or_ids, (int,)):
-            return id_or_ids
+        try:
+            id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
+            if isinstance(id_or_ids, (int,)):
+                return id_or_ids
+        except KeyError:
+            pass
         return False
 
     def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):