Add QA-style data for Alpaca instructions; fix one_cycle scheduler

This commit is contained in:
Wing Lian
2023-05-22 22:58:10 -04:00
parent b029a11e65
commit 3a503770e4
2 changed files with 18 additions and 3 deletions

View File

@@ -1,4 +1,4 @@
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
@@ -6,3 +6,18 @@ def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for dataset rows keyed by "question" and "answer".
    """

    # NOTE(review): `(str, str, str)` is a tuple of types, not a real type
    # annotation; kept as-is to match the existing signature.
    def parse_instruction_fields(self, prompt) -> (str, str, str):
        """
        Pull the question and answer out of a row and return them as a 3-tuple.

        The middle element is always the empty string. Presumably the tuple is
        (instruction, input, response) as expected by the base class -- confirm
        against InstructionPromptTokenizingStrategy.

        Raises KeyError if the row lacks a "question" or "answer" key.
        """
        question = prompt["question"]
        answer = prompt["answer"]
        # This format supplies no value for the middle slot, so it stays blank.
        return question, "", answer
def load_qa(tokenizer, cfg):
    """
    Build an AlpacaQAPromptTokenizingStrategy using a chat-style AlpacaPrompter.

    Reads `train_on_inputs` and `sequence_len` from `cfg` and forwards them,
    together with `tokenizer`, to the strategy constructor.
    """
    prompter = AlpacaPrompter(PromptStyle.chat)
    return AlpacaQAPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

View File

@@ -23,7 +23,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
num_training_steps=num_training_steps
pct_start = num_warmup_steps / num_training_steps
lr_scheduler = OneCycleLR(
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
@@ -31,7 +31,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
div_factor=6,
)
return lr_scheduler
return self.lr_scheduler
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):