support for multi line inference input, log sweep over learning rates
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import random
|
import random
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@@ -44,18 +46,20 @@ def choose_device(cfg):
|
|||||||
cfg.device_map = {"": cfg.device}
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
|
||||||
def do_inference(cfg, model, tokenizer):
|
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||||
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||||
|
|
||||||
from axolotl.prompters import ReflectAlpacaPrompter
|
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
instruction = str(input("Give me an instruction: "))
|
# support for multiline inputs
|
||||||
|
print("Give me an instruction (Ctrl + D to finish): ")
|
||||||
|
instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||||
if not instruction:
|
if not instruction:
|
||||||
return
|
return
|
||||||
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
|
prompt = prompter_module().build_prompt(instruction=instruction)
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -162,6 +166,10 @@ def train(
|
|||||||
do_inference(cfg, model, tokenizer)
|
do_inference(cfg, model, tokenizer)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if "shard" in kwargs:
|
||||||
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
return
|
||||||
|
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
@@ -207,12 +215,11 @@ def train(
|
|||||||
logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
|
logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
logging.info(
|
||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
|
||||||
logging.info(
|
)
|
||||||
f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
|
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
)
|
trainer.save_model(cfg.output_dir)
|
||||||
model.save_pretrained(cfg.output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
34
src/axolotl/utils/schedulers.py
Normal file
34
src/axolotl/utils/schedulers.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolatingLogScheduler(LRScheduler):
    """Scheduler that sweeps the learning rate from ``min_lr`` to ``max_lr`` on a log scale."""

    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
        """A scheduler that interpolates learning rates in a logarithmic fashion

        The learning rate is multiplied by a constant factor ``q`` on every
        scheduler step, starting at ``min_lr`` (step 0) and reaching ``max_lr``
        at step ``num_steps - 1``. From step ``num_steps`` onwards it stays at
        ``max_lr``.

        Args:
        - optimizer: pytorch optimizer
        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
        - min_lr: float, the minimum learning rate (must be > 0 for a log sweep)
        - max_lr: float, the maximum learning rate
        - last_epoch: int, index of the last step, forwarded to ``LRScheduler``

        Raises:
        - ValueError: if ``min_lr`` is not strictly positive

        Usage:
            fc = nn.Linear(1,1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
        """
        if min_lr <= 0:
            raise ValueError(
                f"min_lr must be > 0 for a logarithmic sweep, got {min_lr}"
            )

        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr

        # Per-step growth factor chosen so that
        #   min_lr * q ** (num_steps - 1) == max_lr.
        # The exponent must be 1 / (num_steps - 1); writing it as
        # (1 / num_steps - 1) yields a negative exponent and a decaying lr.
        if num_steps > 1:
            self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
        else:
            # With at most one step there is nothing to interpolate.
            self.q = 1.0

        # The base-class constructor performs the initial step, which calls
        # get_lr(), so every attribute above must be set before this call.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, one entry per param group."""
        if self.last_epoch <= 0:
            lr = self.min_lr
        elif self.last_epoch < self.num_steps:
            # FIXME: not perfect, as we still need to account for how many
            # steps are in an epoch, etc.
            lr = self.min_lr * (self.q ** self.last_epoch)
        else:
            lr = self.max_lr

        return [lr for _ in self.base_lrs]
|
||||||
@@ -12,6 +12,8 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from transformers import EarlyStoppingCallback
|
from transformers import EarlyStoppingCallback
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
|
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
@@ -27,11 +29,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.logging_steps is not None
|
if cfg.logging_steps is not None
|
||||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||||
)
|
)
|
||||||
save_steps = eval_steps = (
|
save_steps = (
|
||||||
cfg.save_steps
|
cfg.save_steps
|
||||||
if cfg.save_steps is not None
|
if cfg.save_steps is not None
|
||||||
else min(int(0.05 * total_num_steps), 200)
|
else min(int(0.05 * total_num_steps), 200)
|
||||||
)
|
)
|
||||||
|
eval_steps = (
|
||||||
|
cfg.eval_steps
|
||||||
|
if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
|
||||||
|
else save_steps
|
||||||
|
)
|
||||||
|
|
||||||
training_arguments_kwargs = {}
|
training_arguments_kwargs = {}
|
||||||
if cfg.bf16 == "full":
|
if cfg.bf16 == "full":
|
||||||
@@ -95,7 +102,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
report_to="wandb" if cfg.use_wandb else None,
|
report_to="wandb" if cfg.use_wandb else None,
|
||||||
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
||||||
optim=cfg.optimizer if cfg.optimizer else None,
|
optim=cfg.optimizer if cfg.optimizer else None,
|
||||||
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
|
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
|
||||||
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
|
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
)
|
)
|
||||||
@@ -147,8 +154,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
optimizer,
|
optimizer,
|
||||||
cfg.learning_rate,
|
cfg.learning_rate,
|
||||||
total_steps=total_num_steps,
|
total_steps=total_num_steps,
|
||||||
|
epochs=cfg.num_epochs,
|
||||||
**lr_scheduler_kwargs,
|
**lr_scheduler_kwargs,
|
||||||
)
|
)
|
||||||
|
elif cfg.lr_scheduler == "log_sweep":
|
||||||
|
lr_scheduler = InterpolatingLogScheduler(
|
||||||
|
optimizer,
|
||||||
|
cfg.warmup_steps,
|
||||||
|
cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
|
||||||
|
cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
|
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
|
||||||
optimizer,
|
optimizer,
|
||||||
|
|||||||
Reference in New Issue
Block a user