Merge pull request #307 from OpenAccess-AI-Collective/xgen-user-sharegpt-tokens
better handling since xgen tokenizer breaks with convert_tokens_to_ids
This commit is contained in:
@@ -48,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
||||
|
||||
Reference in New Issue
Block a user