Add qa style data for alpaca instructions, fix one_cycle scheduler
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||||
|
|
||||||
|
|
||||||
@@ -6,3 +6,18 @@ def load(tokenizer, cfg):
|
|||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||||
|
return (
|
||||||
|
prompt["question"],
|
||||||
|
"",
|
||||||
|
prompt["answer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_qa(tokenizer, cfg):
|
||||||
|
return AlpacaQAPromptTokenizingStrategy(
|
||||||
|
AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
||||||
|
)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
|
|||||||
num_training_steps=num_training_steps
|
num_training_steps=num_training_steps
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
|
||||||
lr_scheduler = OneCycleLR(
|
self.lr_scheduler = OneCycleLR(
|
||||||
optimizer,
|
optimizer,
|
||||||
max_lr=self.args.learning_rate,
|
max_lr=self.args.learning_rate,
|
||||||
total_steps=num_training_steps,
|
total_steps=num_training_steps,
|
||||||
@@ -31,7 +31,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
|
|||||||
div_factor=6,
|
div_factor=6,
|
||||||
)
|
)
|
||||||
|
|
||||||
return lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
|
|||||||
Reference in New Issue
Block a user