From e91fed495a8258406e53f3d98d63b241c8143763 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 Jun 2023 15:47:40 -0400 Subject: [PATCH] better handling for tokenizers like flan that don't have a bos token --- src/axolotl/prompt_tokenizers.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 6408620d7..9d9be8b66 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -73,8 +73,17 @@ class PromptTokenizingStrategy(abc.ABC): ): result["input_ids"].append(self.tokenizer.eos_token_id) result["attention_mask"].append(1) + elif ( # some tokenizers automatically add an eos token, let's remove it + not add_eos_token and result["input_ids"][-1] == self.tokenizer.eos_token_id + ): + result["input_ids"] = result["input_ids"][:-1] + result["attention_mask"] = result["attention_mask"][:-1] - if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: + if ( + self.tokenizer.bos_token_id + and result["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): result["input_ids"] = result["input_ids"][1:] result["attention_mask"] = result["attention_mask"][1:] @@ -412,7 +421,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): result["input_ids"].append(self.tokenizer.eos_token_id) result["attention_mask"].append(1) - if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: + if ( + self.tokenizer.bos_token_id + and result["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): result["input_ids"] = result["input_ids"][1:] result["attention_mask"] = result["attention_mask"][1:]