From 4ac9e251b7572f3abd0adf9e4c12745c820bdd6b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 5 Jun 2023 22:41:00 -0400 Subject: [PATCH] new prompters, misc fixes for output dir missing using fsdp, and changing max seq len --- scripts/finetune.py | 3 + src/axolotl/prompt_strategies/alpaca_chat.py | 40 ++++++++++++ src/axolotl/prompt_strategies/context_qa.py | 67 ++++++++++++++++++++ src/axolotl/utils/models.py | 4 ++ 4 files changed, 114 insertions(+) create mode 100644 src/axolotl/prompt_strategies/context_qa.py diff --git a/scripts/finetune.py b/scripts/finetune.py index 9a2d62904..7c4d865fa 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -279,6 +279,9 @@ def train( logging.info( f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}" ) + + if not Path(cfg.output_dir).is_dir(): + os.makedirs(cfg.output_dir, exist_ok=True) trainer.train(resume_from_checkpoint=resume_from_checkpoint) logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 15dfb65c4..ae2f56fa1 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -18,6 +18,15 @@ def load(tokenizer, cfg): ) +class AlpacaConcisePrompter(AlpacaPrompter): + """ + Alpaca Prompter extending the system prompt to ask for concise answers + """ + + system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n" + system_no_input_prompt = "Below is an instruction that describes a task. 
Write a response that appropriately and concisely completes the request.\n\n"
+
+
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     """
     Tokenizing strategy for AlpacaQA
@@ -31,6 +40,28 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     )
+
+
+class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    """
+    Tokenizing strategy for CamelAI datasets
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        return (
+            prompt["message_1"],
+            "",
+            prompt["message_2"],
+        )
+
+
+def load_concise(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        AlpacaConcisePrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
 
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
         AlpacaPrompter(PromptStyle.CHAT.value),
@@ -38,3 +69,12 @@ def load_qa(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_camel_ai(tokenizer, cfg):
+    return CamelAIPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
diff --git a/src/axolotl/prompt_strategies/context_qa.py b/src/axolotl/prompt_strategies/context_qa.py
new file mode 100644
index 000000000..f7027c7e2
--- /dev/null
+++ b/src/axolotl/prompt_strategies/context_qa.py
@@ -0,0 +1,67 @@
+"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
+from typing import Tuple
+
+from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+# article, unanswerable_question, question, answer
+def load_404(tokenizer, cfg):
+    return AlpacaMissingInfoContextPromptTokenizingStrategy(
+        AlpacaContextPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load(tokenizer, cfg):
+    return AlpacaContextPromptTokenizingStrategy(
+        AlpacaContextPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class AlpacaContextPrompter(AlpacaPrompter):
+    """
+    Customized system prompted for concise QA
+    """
+
+    system_prompt = (
+        "Use the following contextual information to concisely answer the question.\n"
+    )
+    system_no_input_prompt = (
+        "Use the following contextual information to concisely answer the question.\n"
+    )
+
+
+class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    """
+    Tokenization Strategy to combine in-context article with a question and answer
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        return (
+            prompt["article"] + "\n===\n" + prompt["question"],
+            "",
+            prompt["answer"],
+        )
+
+
+class AlpacaMissingInfoContextPromptTokenizingStrategy(
+    InstructionPromptTokenizingStrategy
+):
+    """
+    Tokenization Strategy to combine in-context article with a question that can't be answered
+    from the context and a default response to that effect
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+        return (
+            prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
+            "",
+            "The context provided does not contain any information about your inquiry. "
+            "Therefore, I'm unable to answer your question based on the given context.",
+        )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 28ecebb14..58e0e97ec 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -234,6 +234,10 @@ def load_model(
             base_model,
             trust_remote_code=cfg.trust_remote_code or False,
         )
+        # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
+        # when training starts
+        if getattr(config, "max_seq_len", None) and cfg.sequence_len > config.max_seq_len:
+            config.max_seq_len = cfg.sequence_len
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,