support for multi-line inference input, log sweep over learning rates

Wing Lian
2023-05-03 13:48:54 -04:00
parent 7748f3d6da
commit 9105935b00
3 changed files with 68 additions and 12 deletions

src/axolotl/utils/schedulers.py

@@ -0,0 +1,34 @@
from torch.optim.lr_scheduler import LRScheduler


class InterpolatingLogScheduler(LRScheduler):
    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion

        Args:
        - optimizer: pytorch optimizer
        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
        - min_lr: float, the minimum learning rate
        - max_lr: float, the maximum learning rate

        Usage:
            fc = nn.Linear(1,1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr
        # common ratio of the geometric progression: min_lr * q ** (num_steps - 1) == max_lr
        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            lr = self.min_lr
        elif self.last_epoch < self.num_steps:
            # FIXME: not perfect, as this should also account for how many steps are in an epoch
            lr = self.min_lr * (self.q ** self.last_epoch)
        else:
            lr = self.max_lr

        # return the same interpolated lr for every param group
        return [lr for _ in self.base_lrs]
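
As a quick sanity check of the new scheduler (a minimal sketch, not part of this commit; it assumes torch>=2.0, where LRScheduler is public, and that axolotl is importable), the learning rate should climb by a constant factor each step and then clamp at max_lr:

from torch import nn, optim

from axolotl.utils.schedulers import InterpolatingLogScheduler

fc = nn.Linear(1, 1)
optimizer = optim.Adam(fc.parameters(), lr=1e-6)
# with num_steps=5, q = (1e-2 / 1e-6) ** (1 / 4) = 10, so the lr rises 10x per step
scheduler = InterpolatingLogScheduler(optimizer, num_steps=5, min_lr=1e-6, max_lr=1e-2)

for step in range(7):
    print(step, scheduler.get_last_lr())
    optimizer.step()
    scheduler.step()
# prints (up to floating-point rounding): 1e-06, 1e-05, 1e-04, 1e-03, 1e-02,
# then stays clamped at max_lr=1e-02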

src/axolotl/utils/trainer.py

@@ -12,6 +12,8 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
    total_num_steps = int(
@@ -27,11 +29,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        if cfg.logging_steps is not None
        else max(min(int(0.005 * total_num_steps), 10), 1)
    )
-    save_steps = eval_steps = (
+    save_steps = (
        cfg.save_steps
        if cfg.save_steps is not None
        else min(int(0.05 * total_num_steps), 200)
    )
    eval_steps = (
        cfg.eval_steps
        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
        else save_steps
    )

    training_arguments_kwargs = {}
    if cfg.bf16 == "full":
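
To make the new save/eval defaults concrete (a hypothetical walkthrough, values not from the commit), consider a run with total_num_steps = 1000 and no explicit overrides:

total_num_steps = 1000

logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)  # -> 5
save_steps = min(int(0.05 * total_num_steps), 200)  # -> 50
# cfg.eval_steps is only honored when it divides save_steps evenly, so every
# checkpoint save lands on an eval boundary:
#   cfg.eval_steps = 25 -> eval_steps = 25  (50 % 25 == 0)
#   cfg.eval_steps = 30 -> eval_steps = 50  (falls back to save_steps)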
@@ -95,7 +102,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        report_to="wandb" if cfg.use_wandb else None,
        run_name=cfg.wandb_run_id if cfg.use_wandb else None,
        optim=cfg.optimizer if cfg.optimizer else None,
-        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
+        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
        weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
        **training_arguments_kwargs,
    )
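
A note on the "cosine" placeholder above: lr_scheduler_type must be one of the scheduler names transformers recognizes, so the custom one_cycle and log_sweep schedules cannot be named there; the scheduler object built below is what actually drives the learning rate. A hedged sketch of the wiring (assuming the TrainingArguments instance built above is named training_arguments):

# Handing a pre-built (optimizer, scheduler) pair to the Trainer bypasses
# lr_scheduler_type, which only matters when the Trainer creates its own scheduler.
trainer = transformers.Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    optimizers=(optimizer, lr_scheduler),
)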
@@ -147,8 +154,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
            optimizer,
            cfg.learning_rate,
            total_steps=total_num_steps,
            epochs=cfg.num_epochs,
            **lr_scheduler_kwargs,
        )
    elif cfg.lr_scheduler == "log_sweep":
        lr_scheduler = InterpolatingLogScheduler(
            optimizer,
            cfg.warmup_steps,
            cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
            cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
        )
    else:
        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
            optimizer,