Compare commits
1 Commits
coderabbit
...
flan-no-bo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e91fed495a |
@@ -73,8 +73,17 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
):
|
):
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
elif ( # some tokenizers automatically add an eos token, let's remove it
|
||||||
|
not add_eos_token and result["input_ids"][-1] == self.tokenizer.eos_token_id
|
||||||
|
):
|
||||||
|
result["input_ids"] = result["input_ids"][:-1]
|
||||||
|
result["attention_mask"] = result["attention_mask"][:-1]
|
||||||
|
|
||||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
if (
|
||||||
|
self.tokenizer.bos_token_id
|
||||||
|
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||||
|
and strip_bos_token
|
||||||
|
):
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
@@ -412,7 +421,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
if (
|
||||||
|
self.tokenizer.bos_token_id
|
||||||
|
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||||
|
and strip_bos_token
|
||||||
|
):
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user