diff --git a/configs/llama_7B_jeopardy.yml b/configs/llama_7B_jeopardy.yml
new file mode 100644
index 000000000..1f0fbf9cf
--- /dev/null
+++ b/configs/llama_7B_jeopardy.yml
@@ -0,0 +1,58 @@
+base_model: huggyllama/llama-7b
+base_model_config: huggyllama/llama-7b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: false
+datasets:
+  - path: openaccess-ai-collective/jeopardy
+    type: jeopardy
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project: jeopardy-bot-7b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./jeopardy-bot-7b
+batch_size: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0000002  # NOTE(review): 2e-7 is unusually low for full fine-tuning — confirm (2e-5 is more typical)
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 5
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 20
+eval_steps: 110
+save_steps: 660
+debug:
+deepspeed:
+weight_decay: 0.0001
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: "[PAD]"
+  bos_token: "<s>"  # Llama tokenizer defaults; empty strings would clobber them and break EOS-based stopping
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 1909ec289..8bc81d327 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -89,6 +89,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     )
+class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):  # maps rows of the "jeopardy" dataset type onto the instruction/input/response triple
+    def parse_instruction_fields(self, prompt) -> (str, str, str):  # tuple annotation matches the file's existing convention (see OpenAssistant strategy below)
+        return (
+            prompt["question"],  # the clue text becomes the instruction
+            prompt["category"],  # the category becomes the input
+            "what is " + prompt["answer"],  # Jeopardy-style "what is ..." answer phrasing becomes the response
+        )
+
+
 class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index c2acf60c3..cb3a712b9 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -31,6 +31,10 @@ class AlpacaPrompter:
         return output.split(self.response_split)[1].strip()
+class JeopardyPrompter(AlpacaPrompter):  # overrides only the input-style template; fixed typo "tbe" -> "the"
+    prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers the clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+
+
 class GPTeacherPrompter(AlpacaPrompter):
     ...
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index d315da98c..2c987b4f4 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -11,13 +11,15 @@ from axolotl.prompt_tokenizers import (
     GPTeacherPromptTokenizingStrategy,
     OpenAssistantPromptTokenizingStrategy,
     AlpacaReflectionPTStrategy,
-    ShareGPTPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+    JeopardyPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
     GPTeacherPrompter,
     ReflectAlpacaPrompter,
-    ShareGPTPrompter,
+    ShareGPTPrompter,
+    JeopardyPrompter,
 )
@@ -82,6 +82,12 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
+        elif d.type == "jeopardy":
+            # `elif`, not `if`: a bare `if` here severs the existing dispatch chain,
+            # reattaching the following `elif d.type == "oasst":` to this new test.
+            ds_strategy = JeopardyPromptTokenizingStrategy(
+                JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
         elif d.type == "oasst":
             ds_strategy = OpenAssistantPromptTokenizingStrategy(
                 AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len