Merge pull request #15 from NanoCode012/feat/completion
Feat: Add Completion dataset type
This commit is contained in:
@@ -125,6 +125,25 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for raw-text completion datasets.

    Each example is expected to be a mapping with a single "text" field;
    the text is passed through the prompter and tokenized as one sequence.
    """

    def parse_instruction_fields(self, prompt) -> str:
        # Completion data has no instruction/input/output split — only the
        # raw text.  (Annotation fixed: `-> (str)` is just a parenthesized
        # `str`, not a tuple, and read misleadingly.)
        return prompt["text"]

    def tokenize_prompt(self, prompt):
        """Build the full prompt for one example and return its tokenization."""
        instruction = self.parse_instruction_fields(prompt)
        full_prompt = self._build_full_prompt(instruction)
        tokenized_full_prompt = self._tokenize(full_prompt)

        return tokenized_full_prompt

    def _build_full_prompt(self, instruction):
        # Delegate to the prompter; for CompletionPrompter this is a
        # pass-through of the text.
        return self.prompter.build_prompt(instruction)
||||
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -35,6 +35,17 @@ class JeopardyPrompter(AlpacaPrompter):
|
||||
# Jeopardy prompt template (typo fixed: "tbe clue" -> "the clue").
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers the clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
||||
class CompletionPrompter(AlpacaPrompter):
    """Pass-through prompter for completion data — no template is applied."""

    def build_prompt(self, instruction: str) -> str:
        # The raw text already is the full prompt; return it untouched.
        return instruction

    def get_response(self, output: str) -> str:
        # Trim surrounding whitespace from the generated output.
        return output.strip()
||||
class GPTeacherPrompter(AlpacaPrompter):
    # Inherits AlpacaPrompter behavior unchanged; presumably exists so the
    # GPTeacher dataset type has its own named prompter — verify against the
    # dataset-loading dispatch.
    ...
@@ -11,13 +11,17 @@ from axolotl.prompt_tokenizers import (
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
CompletionPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
GPTeacherPrompter,
|
||||
ReflectAlpacaPrompter,
|
||||
ShareGPTPrompter, JeopardyPrompter,
|
||||
ShareGPTPrompter,
|
||||
JeopardyPrompter,
|
||||
CompletionPrompter,
|
||||
)
|
||||
|
||||
|
||||
@@ -118,6 +122,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
datasets.append(ds_wrapper)
|
||||
elif d.type == "completion":
|
||||
ds_strategy = CompletionPromptTokenizingStrategy(
|
||||
CompletionPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
||||
|
||||
Reference in New Issue
Block a user