Merge pull request #307 from OpenAccess-AI-Collective/xgen-user-sharegpt-tokens
better handling since xgen tokenizer breaks with convert_tokens_to_ids
This commit is contained in:
@@ -48,16 +48,22 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
|
|
||||||
@functools.lru_cache(maxsize=128)
|
@functools.lru_cache(maxsize=128)
|
||||||
def _get_user_token(self):
|
def _get_user_token(self):
|
||||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
try:
|
||||||
if isinstance(id_or_ids, (int,)):
|
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||||
return id_or_ids
|
if isinstance(id_or_ids, (int,)):
|
||||||
|
return id_or_ids
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=128)
|
@functools.lru_cache(maxsize=128)
|
||||||
def _get_assistant_token(self):
|
def _get_assistant_token(self):
|
||||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
try:
|
||||||
if isinstance(id_or_ids, (int,)):
|
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||||
return id_or_ids
|
if isinstance(id_or_ids, (int,)):
|
||||||
|
return id_or_ids
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user