Merge pull request #307 from OpenAccess-AI-Collective/xgen-user-sharegpt-tokens

better handling since xgen tokenizer breaks with convert_tokens_to_ids
This commit is contained in:
Wing Lian
2023-07-22 04:10:38 -04:00
committed by GitHub

View File

@@ -48,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)
def _get_user_token(self): def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>") try:
if isinstance(id_or_ids, (int,)): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
return id_or_ids if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False return False
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)
def _get_assistant_token(self): def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>") try:
if isinstance(id_or_ids, (int,)): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
return id_or_ids if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):