Merge pull request #15 from NanoCode012/feat/completion

Feat: Add Completion dataset type
This commit is contained in:
Wing Lian
2023-05-08 19:04:54 -04:00
committed by GitHub
3 changed files with 45 additions and 2 deletions

View File

@@ -125,6 +125,25 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
)
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str):
return (
prompt["text"]
)
def tokenize_prompt(self, prompt):
instruction = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction)
tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt
def _build_full_prompt(self, instruction):
return self.prompter.build_prompt(
instruction
)
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
raise NotImplementedError

View File

@@ -35,6 +35,17 @@ class JeopardyPrompter(AlpacaPrompter):
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
class CompletionPrompter(AlpacaPrompter):
def build_prompt(
self,
instruction: str
) -> str:
return instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter):
...

View File

@@ -11,13 +11,17 @@ from axolotl.prompt_tokenizers import (
GPTeacherPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
CompletionPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
GPTeacherPrompter,
ReflectAlpacaPrompter,
ShareGPTPrompter, JeopardyPrompter,
ShareGPTPrompter,
JeopardyPrompter,
CompletionPrompter,
)
@@ -118,6 +122,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "completion":
ds_strategy = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
else:
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
logging.info("tokenizing, merging, and shuffling master dataset")