From 5d86137f70f23ea5c1663191ae260510bdd331db Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:20:11 +0900 Subject: [PATCH] Lint prompt_tokenizers --- src/axolotl/prompt_tokenizers.py | 111 +++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 22 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index a91a4e2d3..7febd0a72 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,7 +1,10 @@ +"""Module containing PromptTokenizingStrategy and Prompter classes""" + import abc import copy import functools import logging +from typing import Tuple from transformers import PreTrainedTokenizer @@ -15,10 +18,16 @@ LLAMA_DEFAULT_UNK_TOKEN = "" class InvalidDataException(Exception): - pass + """ + Exception raised when the data is invalid + """ class PromptTokenizingStrategy(abc.ABC): + """ + Abstract class for tokenizing strategies + """ + def __init__( self, prompter, @@ -35,14 +44,14 @@ class PromptTokenizingStrategy(abc.ABC): def tokenize_prompt(self, prompt): pass - @functools.cache + @functools.lru_cache(maxsize=128) def _get_user_token(self): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>") if isinstance(id_or_ids, (int,)): return id_or_ids return False - @functools.cache + @functools.lru_cache(maxsize=128) def _get_assistant_token(self): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>") if isinstance(id_or_ids, (int,)): @@ -51,11 +60,19 @@ class PromptTokenizingStrategy(abc.ABC): class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for instruction-based prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: raise NotImplementedError def tokenize_prompt(self, prompt): - instruction, input, response = self.parse_instruction_fields(prompt) + ( + instruction, + input, # pylint: disable=redefined-builtin + response, + ) = self.parse_instruction_fields(prompt) full_prompt = self._build_full_prompt(instruction, input, response) tokenized_full_prompt = self._tokenize(full_prompt) if not self.train_on_inputs: @@ -76,7 +93,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt(self, instruction, input, response): + def _build_full_prompt( + self, instruction, input, response # pylint: disable=redefined-builtin + ): return next( iter( self.prompter.build_prompt( @@ -112,7 +131,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Alpaca prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["instruction"], prompt["input"] if "input" in prompt else "", @@ -121,7 +144,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Alpaca Multiple Choice prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], "\n".join(f'- "{choice}"' for choice in prompt["choices"]), @@ -130,7 +157,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Jeopardy prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], prompt["category"], @@ -139,7 +170,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for OpenAssistant prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["INSTRUCTION"], "", @@ -148,7 +183,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy) class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for SummarizeTLDR prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["article"], "", @@ -157,7 +196,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy) class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for GPTeacher prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["instruction"], prompt["input"] if "input" in prompt else "", @@ -166,7 +209,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for NomicGPT4All prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["prompt"], "", @@ -175,6 +222,10 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + """ + Tokenizing strategy for Completion prompts. + """ + def parse_instruction_fields(self, prompt) -> str: return prompt["text"] @@ -185,18 +236,24 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt(self, instruction, input, response): + def _build_full_prompt( + self, instruction, input, response + ): # pylint: disable=unused-argument, redefined-builtin return next(iter(self.prompter.build_prompt(instruction))) class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): + """ + Tokenizing strategy for Reflection prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: raise NotImplementedError def tokenize_prompt(self, prompt): ( instruction, - input, + input, # pylint: disable=redefined-builtin output, reflection, corrected, @@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt(self, instruction, input, output, reflection, corrected): + def _build_full_prompt( + self, instruction, input, output, reflection, corrected + ): # pylint: disable=redefined-builtin return next( iter( self.prompter.build_prompt( @@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): + """ + Tokenizing strategy for Alpaca Reflection prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: return ( prompt["instruction"], prompt["input"] if "input" in prompt else "", @@ -268,6 +331,10 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): + """ + Tokenizing strategy for ShareGPT prompts. + """ + def get_conversation_thread(self, prompt): return prompt["conversations"] @@ -281,7 +348,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): user_token = self._get_user_token() assistant_token = self._get_assistant_token() try: - for i, part in enumerate( + for _, part in enumerate( self.prompter.build_prompt(self.get_conversation_thread(prompt)) ): if isinstance(part, tuple): @@ -307,7 +374,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): # not masked out from labels labels = copy.deepcopy(res["input_ids"]) else: - logging.warning("unhandled role: " + part[0]) + logging.warning(f"unhandled role: {part[0]}") else: # this is only ever the first part, should include the bos token and the user query res = self._tokenize( @@ -324,8 +391,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): result["labels"][current_len : current_len + input_len] = labels current_len += input_len return result - except (KeyError, AssertionError, IndexError) as e: - raise InvalidDataException(str(e)) + except (KeyError, AssertionError, IndexError) as err: + raise InvalidDataException(str(err)) from err def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): result = self.tokenizer(