Add Jeopardy dataset support: a LLaMA-7B example config, a prompt tokenizing
strategy, a prompter, and the dataset-loader dispatch branch.

diff --git a/configs/llama_7B_jeopardy.yml b/configs/llama_7B_jeopardy.yml
new file mode 100644
index 000000000..1f0fbf9cf
--- /dev/null
+++ b/configs/llama_7B_jeopardy.yml
@@ -0,0 +1,59 @@
+base_model: huggyllama/llama-7b
+base_model_config: huggyllama/llama-7b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: false
+datasets:
+  - path: openaccess-ai-collective/jeopardy
+    type: jeopardy
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project: jeopardy-bot-7b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./jeopardy-bot-7b
+batch_size: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0000002
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 5
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 20
+eval_steps: 110
+save_steps: 660
+debug:
+deepspeed:
+weight_decay: 0.0001
+fsdp:
+fsdp_config:
+special_tokens:
+  # NOTE(review): bos/eos/unk assumed to be the standard LLaMA tokens - confirm against the tokenizer.
+  pad_token: "[PAD]"
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 1909ec289..8bc81d327 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -89,6 +89,17 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         )
 
 
+class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    """Extract question, category and "what is "-prefixed answer from a record."""
+
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            prompt["category"],
+            "what is " + prompt["answer"],
+        )
+
+
 class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index c2acf60c3..cb3a712b9 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -31,6 +31,12 @@ class AlpacaPrompter:
         return output.split(self.response_split)[1].strip()
 
 
+class JeopardyPrompter(AlpacaPrompter):
+    """AlpacaPrompter variant whose with-input prompt is worded for Jeopardy clues."""
+
+    prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers the clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+
+
 class GPTeacherPrompter(AlpacaPrompter):
     ...
 
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index d315da98c..2c987b4f4 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -11,13 +11,15 @@ from axolotl.prompt_tokenizers import (
     GPTeacherPromptTokenizingStrategy,
     OpenAssistantPromptTokenizingStrategy,
     AlpacaReflectionPTStrategy,
     ShareGPTPromptTokenizingStrategy,
+    JeopardyPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
     GPTeacherPrompter,
     ReflectAlpacaPrompter,
     ShareGPTPrompter,
+    JeopardyPrompter,
 )
 
 
@@ -82,6 +84,12 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            elif d.type == "jeopardy":
+                ds_strategy = JeopardyPromptTokenizingStrategy(
+                    JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
             elif d.type == "oasst":
                 ds_strategy = OpenAssistantPromptTokenizingStrategy(
                     AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len