diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 8bc81d327..167648618 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -125,6 +125,25 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): ) +class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + def parse_instruction_fields(self, prompt) -> (str): + return ( + prompt["text"] + ) + + def tokenize_prompt(self, prompt): + instruction = self.parse_instruction_fields(prompt) + full_prompt = self._build_full_prompt(instruction) + tokenized_full_prompt = self._tokenize(full_prompt) + + return tokenized_full_prompt + + def _build_full_prompt(self, instruction): + return self.prompter.build_prompt( + instruction + ) + + class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): raise NotImplementedError diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index cb3a712b9..914cbd0de 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -35,6 +35,17 @@ class JeopardyPrompter(AlpacaPrompter): prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" +class CompletionPrompter(AlpacaPrompter): + def build_prompt( + self, + instruction: str + ) -> str: + return instruction + + def get_response(self, output: str) -> str: + return output.strip() + + class GPTeacherPrompter(AlpacaPrompter): ... diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 2c987b4f4..a6cc6cdcd 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -11,13 +11,17 @@ from axolotl.prompt_tokenizers import ( GPTeacherPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy, - ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy, + ShareGPTPromptTokenizingStrategy, + JeopardyPromptTokenizingStrategy, + CompletionPromptTokenizingStrategy, ) from axolotl.prompters import ( AlpacaPrompter, GPTeacherPrompter, ReflectAlpacaPrompter, - ShareGPTPrompter, JeopardyPrompter, + ShareGPTPrompter, + JeopardyPrompter, + CompletionPrompter, ) @@ -118,6 +122,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) + elif d.type == "completion": + ds_strategy = CompletionPromptTokenizingStrategy( + CompletionPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) else: logging.error(f"unhandled prompt tokenization strategy: {d.type}") logging.info("tokenizing, merging, and shuffling master dataset")