From 8e46c0fb0ddd1dbeb4f31a542ae18d873563192f Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 29 May 2023 15:08:26 +0900
Subject: [PATCH] Refactor duplicate code between Prompter and Pygmalion

---
 src/axolotl/prompt_strategies/pygmalion.py |  51 +++-------
 src/axolotl/prompt_tokenizers.py           | 111 +++++++++++++--------
 2 files changed, 86 insertions(+), 76 deletions(-)

diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py
index 01828a034..4cd9a1685 100644
--- a/src/axolotl/prompt_strategies/pygmalion.py
+++ b/src/axolotl/prompt_strategies/pygmalion.py
@@ -5,7 +5,11 @@
 import logging
 from collections import defaultdict
 from typing import Generator
 
-from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompt_tokenizers import (
+    PromptTokenizingStrategy,
+    parse_tokenized_to_result,
+    tokenize_prompt_default,
+)
 
 IGNORE_TOKEN_ID = -100
@@ -23,12 +27,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
         self.bot_prefix_token_ids = res["input_ids"]
 
     def tokenize_prompt(self, prompt):
-        result = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-        }
-        current_len = 0
+        result, current_len = tokenize_prompt_default()
         for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
             role, message = part
             if role == "system":
@@ -67,37 +66,15 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             else:
                 logging.warning(f"unknown role in conversation: {role}")
                 res = defaultdict(lambda: [])
-            input_ids = res["input_ids"]
-            input_len = len(input_ids)
-            result["input_ids"][current_len : current_len + input_len] = input_ids
-            result["attention_mask"][current_len : current_len + input_len] = [
-                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
-            ]
-            result["labels"][current_len : current_len + input_len] = labels
-            current_len += input_len
-        return result
 
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
-            max_length=self.sequence_len,
-            padding=False,
-            return_tensors=None,
-        )
-        if (
-            result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
-
-        result["labels"] = result["input_ids"].copy()
+            # pylint: disable=duplicate-code
+            result, current_len = parse_tokenized_to_result(
+                result,
+                current_len,
+                res,
+                labels,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
         return result
 
 
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 7febd0a72..ceb65e2ab 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -4,7 +4,7 @@
 import abc
 import copy
 import functools
 import logging
-from typing import Tuple
+from typing import Dict, List, Optional, Tuple
 
 from transformers import PreTrainedTokenizer
@@ -58,6 +58,29 @@ class PromptTokenizingStrategy(abc.ABC):
                 return id_or_ids
         return False
 
+    def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
+            result["input_ids"] = result["input_ids"][1:]
+            result["attention_mask"] = result["attention_mask"][1:]
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
 
 class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -106,29 +129,6 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             )
         )
 
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        result = self.tokenizer(
-            prompt,
-            truncation=True,
-            max_length=self.sequence_len,
-            padding=False,
-            return_tensors=None,
-        )
-        if (
-            result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
-
-        result["labels"] = result["input_ids"].copy()
-        return result
-
 
 class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     """
@@ -295,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
             )
         )
 
-    def _tokenize(self, prompt, add_eos_token=True):
+    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
             prompt,
             truncation=True,
@@ -339,12 +339,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         return prompt["conversations"]
 
     def tokenize_prompt(self, prompt):
-        result = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-        }
-        current_len = 0
+        result, current_len = tokenize_prompt_default()
         user_token = self._get_user_token()
         assistant_token = self._get_assistant_token()
         try:
@@ -382,14 +377,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     )
                     # everything from this is masked out from the labels
                     labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                input_ids = res["input_ids"]
-                input_len = len(input_ids)
-                result["input_ids"][current_len : current_len + input_len] = input_ids
-                result["attention_mask"][current_len : current_len + input_len] = [
-                    1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
-                ]
-                result["labels"][current_len : current_len + input_len] = labels
-                current_len += input_len
+
+                # pylint: disable=duplicate-code
+                result, current_len = parse_tokenized_to_result(
+                    result,
+                    current_len,
+                    res,
+                    labels,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                )
             return result
         except (KeyError, AssertionError, IndexError) as err:
             raise InvalidDataException(str(err)) from err
@@ -416,3 +412,40 @@
         result["labels"] = result["input_ids"].copy()
 
         return result
+
+
+def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
+    """
+    Returns the default values for the tokenize prompt function
+    """
+    result = {
+        "input_ids": [],
+        "attention_mask": [],
+        "labels": [],
+    }
+    current_len = 0
+    return result, current_len
+
+
+def parse_tokenized_to_result(
+    result: Dict[str, List[int]],
+    current_len: int,
+    res: Dict[str, List[int]],
+    labels: List[int],
+    pad_token_id: Optional[int] = None,
+) -> Tuple[Dict[str, List[int]], int]:
+    """
+    Parses the tokenized prompt and appends the tokenized input_ids,
+    attention_mask, and labels to the result
+    """
+
+    input_ids = res["input_ids"]
+    input_len = len(input_ids)
+    result["input_ids"][current_len : current_len + input_len] = input_ids
+    result["attention_mask"][current_len : current_len + input_len] = [
+        1 if x != pad_token_id else 0 for x in input_ids
+    ]
+    result["labels"][current_len : current_len + input_len] = labels
+    current_len += input_len
+
+    return result, current_len
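
Illustrative usage (not part of the diff): a minimal sketch of how the two new helpers compose inside a tokenize_prompt implementation, assuming this patch is applied. The turn data and the pad_token_id of 0 below are hypothetical stand-ins, not values taken from the patch.

    # Hypothetical example of the refactored flow; the fake turns are
    # stand-ins, not data from this patch.
    from axolotl.prompt_tokenizers import (
        parse_tokenized_to_result,
        tokenize_prompt_default,
    )

    IGNORE_TOKEN_ID = -100  # same sentinel the strategies use to mask labels

    # Fresh accumulator: empty input_ids/attention_mask/labels, length 0.
    result, current_len = tokenize_prompt_default()

    # Pretend two conversation turns were already tokenized elsewhere.
    turns = [
        {"input_ids": [101, 7592, 102]},  # user turn: masked out of the loss
        {"input_ids": [101, 2088, 102]},  # bot turn: kept as labels
    ]
    for i, res in enumerate(turns):
        labels = (
            [IGNORE_TOKEN_ID] * len(res["input_ids"])  # mask the user turn
            if i == 0
            else list(res["input_ids"])  # train on the bot turn
        )
        result, current_len = parse_tokenized_to_result(
            result, current_len, res, labels, pad_token_id=0
        )

    assert current_len == 6
    assert result["labels"][:3] == [IGNORE_TOKEN_ID] * 3
    assert result["attention_mask"] == [1] * 6  # no token id equals pad_token_id

Keeping the accumulator logic in one place means the Pygmalion and ShareGPT strategies can no longer drift apart in how they build attention masks or splice labels into the running result.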