From cf681537ecc401b769b751d42a13535742e2b237 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 May 2023 00:30:36 +0900 Subject: [PATCH 1/2] Add CompletionPrompt type --- src/axolotl/prompt_tokenizers.py | 19 +++++++++++++++++++ src/axolotl/prompters.py | 11 +++++++++++ src/axolotl/utils/data.py | 17 +++++++++++++++-- 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 8bc81d327..16dd7eea4 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -125,6 +125,25 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): ) +class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + def parse_instruction_fields(self, prompt) -> (str): + return ( + prompt["text"] + ) + + def tokenize_prompt(self, prompt): + text = self.parse_instruction_fields(prompt) + full_prompt = self._build_full_prompt(text) + tokenized_full_prompt = self._tokenize(full_prompt) + + return tokenized_full_prompt + + def _build_full_prompt(self, text): + return self.prompter.build_prompt( + text + ) + + class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): raise NotImplementedError diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index cb3a712b9..920a50b78 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -35,6 +35,17 @@ class JeopardyPrompter(AlpacaPrompter): prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" +class CompletionPrompter(AlpacaPrompter): + def build_prompt( + self, + text: str + ) -> str: + return text + + def get_response(self, output: str) -> str: + return output.strip() + + class GPTeacherPrompter(AlpacaPrompter): ... diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 2c987b4f4..a6cc6cdcd 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -11,13 +11,17 @@ from axolotl.prompt_tokenizers import ( GPTeacherPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy, - ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy, + ShareGPTPromptTokenizingStrategy, + JeopardyPromptTokenizingStrategy, + CompletionPromptTokenizingStrategy, ) from axolotl.prompters import ( AlpacaPrompter, GPTeacherPrompter, ReflectAlpacaPrompter, - ShareGPTPrompter, JeopardyPrompter, + ShareGPTPrompter, + JeopardyPrompter, + CompletionPrompter, ) @@ -118,6 +122,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) + elif d.type == "completion": + ds_strategy = CompletionPromptTokenizingStrategy( + CompletionPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) else: logging.error(f"unhandled prompt tokenization strategy: {d.type}") logging.info("tokenizing, merging, and shuffling master dataset") From 174b74ddc99acd2b541438826ad871ec16f740a9 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 May 2023 00:54:46 +0900 Subject: [PATCH 2/2] Rename variable to use same convention --- src/axolotl/prompt_tokenizers.py | 8 ++++---- src/axolotl/prompters.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 16dd7eea4..167648618 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -132,15 +132,15 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): ) def tokenize_prompt(self, prompt): - text = self.parse_instruction_fields(prompt) - full_prompt = self._build_full_prompt(text) + instruction = self.parse_instruction_fields(prompt) + full_prompt = self._build_full_prompt(instruction) tokenized_full_prompt = self._tokenize(full_prompt) return tokenized_full_prompt - def _build_full_prompt(self, text): + def _build_full_prompt(self, instruction): return self.prompter.build_prompt( - text + instruction ) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 920a50b78..914cbd0de 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -38,9 +38,9 @@ class JeopardyPrompter(AlpacaPrompter): class CompletionPrompter(AlpacaPrompter): def build_prompt( self, - text: str + instruction: str ) -> str: - return text + return instruction def get_response(self, output: str) -> str: return output.strip()