From 9105935b0065e18466d7aaff4d51fde05353ca27 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 3 May 2023 13:48:54 -0400
Subject: [PATCH] support for multi line inference input, log sweep over
 learning rates

---
 scripts/finetune.py             | 27 +++++++++++++++++----------
 src/axolotl/utils/schedulers.py | 34 +++++++++++++++++++++++++++++++++
 src/axolotl/utils/trainer.py    | 19 +++++++++++++++++--
 3 files changed, 68 insertions(+), 12 deletions(-)
 create mode 100644 src/axolotl/utils/schedulers.py

diff --git a/scripts/finetune.py b/scripts/finetune.py
index cf740e00e..a8cfe2a03 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -1,5 +1,7 @@
+import importlib
 import logging
 import os
+import pathlib
 import random
 import signal
 import sys
@@ -44,18 +46,20 @@ def choose_device(cfg):
         cfg.device_map = {"": cfg.device}
 
 
-def do_inference(cfg, model, tokenizer):
+def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
     tokenizer.add_special_tokens({"eos_token": "</s>"})
 
-    from axolotl.prompters import ReflectAlpacaPrompter
+    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
     while True:
-        instruction = str(input("Give me an instruction: "))
+        # support for multiline inputs
+        print("Give me an instruction (Ctrl + D to finish): ")
+        instruction = pathlib.Path("/proc/self/fd/0").read_text()
         if not instruction:
             return
-        prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+        prompt = prompter_module().build_prompt(instruction=instruction)
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
         model.eval()
@@ -162,6 +166,10 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
+    if "shard" in kwargs:
+        model.save_pretrained(cfg.output_dir)
+        return
+
     train_dataset, eval_dataset = load_prepare_datasets(
         tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
     )
@@ -207,12 +215,11 @@ def train(
         logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
-    if cfg.local_rank == 0:
-        # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(
-            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
-        )
-        model.save_pretrained(cfg.output_dir)
+    logging.info(
+        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+    )
+    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+    trainer.save_model(cfg.output_dir)
 
 
 if __name__ == "__main__":
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
new file mode 100644
index 000000000..4a90436e7
--- /dev/null
+++ b/src/axolotl/utils/schedulers.py
@@ -0,0 +1,34 @@
+from torch.optim.lr_scheduler import LRScheduler
+
+
+class InterpolatingLogScheduler(LRScheduler):
+    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
+        """A scheduler that interpolates learning rates in a logarithmic fashion
+
+        Args:
+        - optimizer: pytorch optimizer
+        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
+        - min_lr: float, the minimum learning rate
+        - max_lr: float, the maximum learning rate
+
+        Usage:
+            fc = nn.Linear(1, 1)
+            optimizer = optim.Adam(fc.parameters())
+            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
+        """
+        self.num_steps = num_steps
+        self.min_lr = min_lr
+        self.max_lr = max_lr
+        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            lr = self.min_lr
+        elif self.last_epoch < self.num_steps:
+            # FIXME: not perfect, as we need to account for how many steps are in an epoch, etc.
+            lr = self.min_lr * (self.q ** self.last_epoch)
+        else:
+            lr = self.max_lr
+
+        return [lr for _ in self.base_lrs]
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 73be3dbd2..63c6856b7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -12,6 +12,8 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback
 from transformers.trainer_pt_utils import get_parameter_names
 
+from axolotl.utils.schedulers import InterpolatingLogScheduler
+
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
@@ -27,11 +29,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
-    save_steps = eval_steps = (
+    save_steps = (
         cfg.save_steps
         if cfg.save_steps is not None
         else min(int(0.05 * total_num_steps), 200)
     )
+    eval_steps = (
+        cfg.eval_steps
+        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
+        else save_steps
+    )
 
     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
@@ -95,7 +102,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else None,
-        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
+        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
         **training_arguments_kwargs,
     )
@@ -147,8 +154,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             optimizer,
             cfg.learning_rate,
             total_steps=total_num_steps,
+            epochs=cfg.num_epochs,
             **lr_scheduler_kwargs,
         )
+    elif cfg.lr_scheduler == "log_sweep":
+        lr_scheduler = InterpolatingLogScheduler(
+            optimizer,
+            cfg.warmup_steps,
+            cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
+            cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
+        )
     else:
         lr_scheduler = transformers.get_cosine_schedule_with_warmup(
             optimizer,
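Notes (illustrative sketches, not part of the diff):

1) Reading /proc/self/fd/0 is Linux-specific; on other platforms sys.stdin.read()
gives the same read-everything-until-EOF, multi-line behavior. A minimal
equivalent of the new input loop body:

    import sys

    print("Give me an instruction (Ctrl + D to finish): ")
    instruction = sys.stdin.read()  # blocks until EOF (Ctrl + D); newlines preserved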
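2) With the exponent written as 1 / (num_steps - 1), the sweep steps
geometrically from min_lr to max_lr and then holds. A quick sanity check
(assumes torch >= 2.0, where LRScheduler is public; older releases expose it
as _LRScheduler):

    from torch import nn, optim

    from axolotl.utils.schedulers import InterpolatingLogScheduler

    fc = nn.Linear(1, 1)
    optimizer = optim.Adam(fc.parameters(), lr=1e-6)
    scheduler = InterpolatingLogScheduler(optimizer, num_steps=5, min_lr=1e-6, max_lr=1e-2)

    for step in range(7):
        print(step, optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()
    # prints approximately 1e-06, 1e-05, 1e-04, 1e-03, 1e-02, then stays clamped at 1e-02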
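3) The new branch is selected by cfg.lr_scheduler == "log_sweep" and reuses
cfg.warmup_steps as the sweep length; the bounds fall back to 1e-10 and 10
when unset. A hypothetical cfg fragment (key names are taken from this diff;
the plain-dict form is assumed here purely for illustration):

    cfg = {
        "lr_scheduler": "log_sweep",  # routed to InterpolatingLogScheduler in setup_trainer
        "warmup_steps": 400,          # reused as num_steps for the sweep
        "log_sweep_min_lr": 1e-6,     # falls back to 1e-10 if unset
        "log_sweep_max_lr": 1e-2,     # falls back to 10 if unset
    }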