Add qa style data for alpaca instructions, fix one_cycle scheduler
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
@@ -6,3 +6,18 @@ def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
)
|
||||
|
||||
|
||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
return (
|
||||
prompt["question"],
|
||||
"",
|
||||
prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
|
||||
num_training_steps=num_training_steps
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
lr_scheduler = OneCycleLR(
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
@@ -31,7 +31,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return lr_scheduler
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
|
||||
Reference in New Issue
Block a user