From 5e3714475439523568840784970728e963b95374 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 15 May 2023 22:15:36 -0400 Subject: [PATCH] fix prompters, especially the sharegpt prompter --- src/axolotl/prompt_tokenizers.py | 84 +++++++++++++++++++++++++----- src/axolotl/prompters.py | 88 ++++++-------------------------- 2 files changed, 89 insertions(+), 83 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 00d8ecbf9..5792d191b 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,7 +1,10 @@ import abc +import copy from transformers import PreTrainedTokenizer +from axolotl.prompters import IGNORE_TOKEN_ID + IGNORE_INDEX = -100 LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" LLAMA_DEFAULT_EOS_TOKEN = "" @@ -40,10 +43,10 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): full_prompt = self._build_full_prompt(instruction, input, response) tokenized_full_prompt = self._tokenize(full_prompt) if not self.train_on_inputs: - user_prompt = self.prompter.build_prompt( + user_prompt = next(iter(self.prompter.build_prompt( instruction, input, - ) + ))) tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) user_prompt_len = len(tokenized_user_prompt["input_ids"]) # TODO this could be sped up using numpy array slicing @@ -54,11 +57,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt def _build_full_prompt(self, instruction, input, response): - return self.prompter.build_prompt( + return next(iter(self.prompter.build_prompt( instruction, input, response, - ) + ))) def _tokenize(self, prompt, add_eos_token=True): result = self.tokenizer( @@ -131,13 +134,13 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): def tokenize_prompt(self, prompt): instruction = self.parse_instruction_fields(prompt) - full_prompt = self._build_full_prompt(instruction) + full_prompt = self._build_full_prompt(instruction, None, None) tokenized_full_prompt = self._tokenize(full_prompt) return tokenized_full_prompt - def _build_full_prompt(self, instruction): - return self.prompter.build_prompt(instruction) + def _build_full_prompt(self, instruction, input, response): + return next(iter(self.prompter.build_prompt(instruction))) class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): @@ -157,10 +160,10 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): ) tokenized_full_prompt = self._tokenize(full_prompt) if not self.train_on_inputs: - user_prompt = self.prompter.build_prompt( + user_prompt = next(iter(self.prompter.build_prompt( instruction, input, - ) + ))) tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False) user_prompt_len = len(tokenized_user_prompt["input_ids"]) # TODO this could be sped up using numpy array slicing @@ -171,13 +174,13 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt def _build_full_prompt(self, instruction, input, output, reflection, corrected): - return self.prompter.build_prompt( + return next(iter(self.prompter.build_prompt( instruction, input, output, reflection, corrected, - ) + ))) def _tokenize(self, prompt, add_eos_token=True): result = self.tokenizer( @@ -212,7 +215,64 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): def tokenize_prompt(self, prompt): + result = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + current_len = 0 try: - return self.prompter.build_prompt(prompt["conversations"], self.tokenizer) + for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"], self.tokenizer)): + if i == 0: + # this is only ever the first part, should include the bos token and the user query + res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False) + # everything from this is masked out from the labels + labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) + elif i % 2 == 0: + # this is still the user query, we should + res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True) + # everything from this is masked out from the labels + labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) + else: + # this should be the assistent response, should end with an eos token + res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True) + # not masked out from labels + labels = copy.deepcopy(res["input_ids"]) + input_ids = res["input_ids"] + input_len = len(input_ids) + result["input_ids"][current_len : current_len + input_len] = input_ids + result["attention_mask"][current_len : current_len + input_len] = [ + 1 if x != self.tokenizer.pad_token_id else 0 + for x in input_ids + ] + result["labels"][current_len : current_len + input_len] = labels + current_len += input_len + return result except (KeyError, AssertionError, IndexError) as e: raise InvalidDataException(str(e)) + + def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.sequence_len, + padding=False, + return_tensors=None, + ) + if ( + result["input_ids"][-1] != self.tokenizer.eos_token_id + and len(result["input_ids"]) < self.sequence_len + and add_eos_token + ): + result["input_ids"].append(self.tokenizer.eos_token_id) + result["attention_mask"].append(1) + + if ( + result["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + result["input_ids"] = result["input_ids"][1:] + result["attention_mask"] = result["attention_mask"][1:] + + result["labels"] = result["input_ids"].copy() + return result diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 3dc5d6433..a52ed4ad9 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,7 +1,7 @@ import copy import dataclasses from enum import auto, Enum -from typing import List, Tuple, Any, Union +from typing import List, Tuple, Any, Union, Generator IGNORE_TOKEN_ID = -100 @@ -16,7 +16,7 @@ class AlpacaPrompter: instruction: str, input: Union[None, str] = None, output: Union[None, str] = None, - ) -> str: + ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. if input: @@ -25,7 +25,7 @@ class AlpacaPrompter: res = self.prompt_no_input.format(instruction=instruction) if output: res = f"{res}{output}" - return res + yield res def get_response(self, output: str) -> str: return output.split(self.response_split)[1].strip() @@ -36,8 +36,8 @@ class JeopardyPrompter(AlpacaPrompter): class CompletionPrompter(AlpacaPrompter): - def build_prompt(self, instruction: str) -> str: - return instruction + def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]: + yield instruction def get_response(self, output: str) -> str: return output.strip() @@ -64,7 +64,7 @@ class ReflectAlpacaPrompter: output: Union[None, str] = None, reflection: Union[None, str] = None, corrected: Union[None, str] = None, - ) -> str: + ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. if input: @@ -76,7 +76,7 @@ class ReflectAlpacaPrompter: output=output, reflection=reflection, corrected=corrected ) res = f"{res}{label}" - return res + yield res def get_response(self, output: str) -> str: return output.split(self.response_split)[1].strip() @@ -103,15 +103,16 @@ class Conversation: sep: str = "###" sep2: str = None - def get_prompt(self): + def get_prompt(self) -> Generator[str, None, None]: seps = [self.sep, self.sep2] - ret = self.system + seps[0] + preamble = self.system + seps[0] for i, (role, message) in enumerate(self.messages): if message: - ret += role + ": " + message + seps[i % 2] + yield preamble + role + ": " + message + seps[i % 2] else: - ret += role + ":" - return ret + yield role + ":" + if i == 0: + preamble = "" def copy(self): return Conversation( @@ -136,12 +137,12 @@ conv_vicuna_v1_1 = Conversation( offset=0, sep_style=SeparatorStyle.TWO, sep=" ", - sep2="", + sep2=" ", ) class ShareGPTPrompter: - def build_prompt(self, source, tokenizer, sequence_len=2048): + def build_prompt(self, source, tokenizer, sequence_len=2048) -> Generator[str, None, None]: # ignore the system prompt if provided if source[0]["from"] == "system": source.pop(0) @@ -171,61 +172,6 @@ class ShareGPTPrompter: role = roles[sentence["from"]] assert role == conv.roles[j % 2] conv.append_message(role, sentence["value"]) - # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up - conversation = conv.get_prompt() - # Tokenize conversations - tokenized_result = tokenizer( - conversation, - truncation=True, - max_length=sequence_len, # FIXME - padding=False, - return_tensors=None, - ) - target = copy.deepcopy(tokenized_result["input_ids"]) - - # Mask targets - sep = conv.sep + conv.roles[1] + ": " - - rounds = conversation.split(conv.sep2) - rounds = [r + conv.sep2 for r in rounds] - cur_len = 1 - target[0] = IGNORE_TOKEN_ID # mask out the bos - for i, rou in enumerate(rounds): - if rou == "": - break - - parts = rou.split(sep) - if len(parts) != 2: - break - parts[0] += sep - round_len = ( - len(tokenizer(rou)["input_ids"]) - 1 - ) # -1 ignores the bos_token generated for this - # we have to strip the initial part, any dangling whitespace creates an additional ghost token - instruction_len = ( - len(tokenizer(parts[0].strip())["input_ids"]) - 1 - ) # -1 ignores the bos_token generated for this - target[cur_len : cur_len + instruction_len] = [ - IGNORE_TOKEN_ID - ] * instruction_len - - cur_len += round_len - if cur_len >= sequence_len: - break - - # Fix: Truncate the target to have the same length as input_ids - target = target[: len(tokenized_result["input_ids"])] - # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len) - - attention_mask = [ - 1 if x != tokenizer.pad_token_id else 0 - for x in tokenized_result["input_ids"] - ] - - # TODO truncate len to sequence_len - return dict( - input_ids=tokenized_result["input_ids"], - labels=target, - attention_mask=attention_mask, - ) + for part in conv.get_prompt(): + yield part