support for multi line inference input, log sweep over learning rates

This commit is contained in:
Wing Lian
2023-05-03 13:48:54 -04:00
parent 7748f3d6da
commit 9105935b00
3 changed files with 68 additions and 12 deletions

View File

@@ -1,5 +1,7 @@
import importlib
import logging
import os
import pathlib
import random
import signal
import sys
@@ -44,18 +46,20 @@ def choose_device(cfg):
cfg.device_map = {"": cfg.device}
def do_inference(cfg, model, tokenizer):
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"})
tokenizer.add_special_tokens({"bos_token": "<s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"})
from axolotl.prompters import ReflectAlpacaPrompter
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
while True:
instruction = str(input("Give me an instruction: "))
# support for multiline inputs
print("Give me an instruction (Ctrl + D to finish): ")
instruction = pathlib.Path("/proc/self/fd/0").read_text()
if not instruction:
return
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
prompt = prompter_module().build_prompt(instruction=instruction)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
@@ -162,6 +166,10 @@ def train(
do_inference(cfg, model, tokenizer)
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir)
return
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
@@ -207,12 +215,11 @@ def train(
logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
if cfg.local_rank == 0:
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
logging.info(
f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
)
model.save_pretrained(cfg.output_dir)
logging.info(
f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
)
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
trainer.save_model(cfg.output_dir)
if __name__ == "__main__":

View File

@@ -0,0 +1,34 @@
from torch.optim.lr_scheduler import LRScheduler
class InterpolatingLogScheduler(LRScheduler):
    """A scheduler that interpolates learning rates in a logarithmic fashion.

    The learning rate starts at ``min_lr`` and is multiplied by a constant
    ratio every step so that it reaches ``max_lr`` at step ``num_steps - 1``;
    after the sweep window it is held constant at ``max_lr``.

    Args:
        optimizer: wrapped pytorch optimizer
        num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
        min_lr: float, the minimum learning rate
        max_lr: float, the maximum learning rate
        last_epoch: int, index of the last step when resuming (default -1)

    Usage:
        fc = nn.Linear(1,1)
        optimizer = optim.Adam(fc.parameters())
        lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
    """

    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        # Per-step multiplicative ratio: min_lr * q**(num_steps - 1) == max_lr.
        # BUGFIX: the original exponent was written `1 / num_steps - 1`, which
        # parses as (1/num_steps) - 1 ~= -1, making q ~= min_lr/max_lr so the
        # lr *shrank* instead of sweeping upward. The parentheses are fixed
        # here; max(..., 1) guards against division by zero when num_steps == 1.
        self.q = (max_lr / min_lr) ** (1 / max(num_steps - 1, 1))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            # First call (issued by LRScheduler.__init__): start at the floor.
            lr = self.min_lr
        elif self.last_epoch < self.num_steps:
            # FIXME, not perfect as we need to account for how many steps are
            # in an epoch, etc
            lr = self.min_lr * (self.q**self.last_epoch)
        else:
            # Past the sweep window: hold at the ceiling.
            lr = self.max_lr
        # One entry per param group, matching the LRScheduler contract.
        return [lr for _ in self.base_lrs]

View File

@@ -12,6 +12,8 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
total_num_steps = int(
@@ -27,11 +29,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
save_steps = eval_steps = (
save_steps = (
cfg.save_steps
if cfg.save_steps is not None
else min(int(0.05 * total_num_steps), 200)
)
eval_steps = (
cfg.eval_steps
if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
else save_steps
)
training_arguments_kwargs = {}
if cfg.bf16 == "full":
@@ -95,7 +102,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else None,
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
**training_arguments_kwargs,
)
@@ -147,8 +154,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
optimizer,
cfg.learning_rate,
total_steps=total_num_steps,
epochs=cfg.num_epochs,
**lr_scheduler_kwargs,
)
elif cfg.lr_scheduler == "log_sweep":
lr_scheduler = InterpolatingLogScheduler(
optimizer,
cfg.warmup_steps,
cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
)
else:
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
optimizer,