diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index fb6f39b0d..e223b6d76 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -48,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
 
     @functools.lru_cache(maxsize=128)
     def _get_user_token(self):
-        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
-        if isinstance(id_or_ids, (int,)):
-            return id_or_ids
+        try:
+            id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
+            if isinstance(id_or_ids, (int,)):
+                return id_or_ids
+        except KeyError:
+            pass
         return False
 
     @functools.lru_cache(maxsize=128)
     def _get_assistant_token(self):
-        id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
-        if isinstance(id_or_ids, (int,)):
-            return id_or_ids
+        try:
+            id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
+            if isinstance(id_or_ids, (int,)):
+                return id_or_ids
+        except KeyError:
+            pass
         return False
 
     def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):