Jeopardy bot! (#17)
* support for jeopardy dataset * commit the final config for jeopardy bot
This commit is contained in:
@@ -89,6 +89,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
return (
|
||||
prompt["question"],
|
||||
prompt["category"],
|
||||
"what is " + prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
return (
|
||||
|
||||
@@ -31,6 +31,10 @@ class AlpacaPrompter:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class JeopardyPrompter(AlpacaPrompter):
|
||||
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
|
||||
|
||||
class GPTeacherPrompter(AlpacaPrompter):
|
||||
...
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@ from axolotl.prompt_tokenizers import (
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
GPTeacherPrompter,
|
||||
ReflectAlpacaPrompter,
|
||||
ShareGPTPrompter,
|
||||
ShareGPTPrompter, JeopardyPrompter,
|
||||
)
|
||||
|
||||
|
||||
@@ -82,6 +82,12 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
datasets.append(ds_wrapper)
|
||||
if d.type == "jeopardy":
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
datasets.append(ds_wrapper)
|
||||
elif d.type == "oasst":
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
|
||||
Reference in New Issue
Block a user