Handle tokenizers such as xgen whose convert_tokens_to_ids raises KeyError

This commit is contained in:
Wing Lian
2023-07-21 09:24:11 -04:00
parent 06c61d6f13
commit 2a428e8014

View File

@@ -48,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)
def _get_user_token(self): def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>") try:
if isinstance(id_or_ids, (int,)): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
return id_or_ids if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False return False
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)
def _get_assistant_token(self): def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>") try:
if isinstance(id_or_ids, (int,)): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
return id_or_ids if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):