diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 1cd99bd9f..2a22d17e1 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -1,4 +1,4 @@
-from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 
@@ -6,3 +6,18 @@ def load(tokenizer, cfg):
     return AlpacaPromptTokenizingStrategy(
         AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
     )
+
+
+class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            "",
+            prompt["answer"],
+        )
+
+
+def load_qa(tokenizer, cfg):
+    return AlpacaQAPromptTokenizingStrategy(
+        AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+    )
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 4336f740c..12e85e15e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -23,7 +23,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
         num_training_steps=num_training_steps
         pct_start = num_warmup_steps / num_training_steps
 
-        lr_scheduler = OneCycleLR(
+        self.lr_scheduler = OneCycleLR(
             optimizer,
             max_lr=self.args.learning_rate,
             total_steps=num_training_steps,
@@ -31,7 +31,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
             div_factor=6,
         )
 
-        return lr_scheduler
+        return self.lr_scheduler
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):