Compare commits


1 Commit

Wing Lian · d465b9fd98 · wip, jagged restarts · 2024-02-16 14:34:08 -05:00
4 changed files with 146 additions and 4 deletions

src/axolotl/core/trainer_builder.py

@@ -49,7 +49,7 @@ from axolotl.utils.collators import (
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
-    get_cosine_schedule_with_quadratic_warmup,
+    get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler,
 )
 
 try:
@@ -129,7 +129,19 @@ class AxolotlTrainingArguments(TrainingArguments):
     )
     relora_anneal_steps: Optional[int] = field(
         default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+        metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
+    )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restarts_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
+    )
+    jagged_restarts_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
     )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
@@ -226,7 +238,7 @@ class AxolotlTrainer(Trainer):
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer)
+                super().create_scheduler(num_training_steps, optimizer)
         else:
             if use_cosine_quadratic:
                 LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -234,6 +246,21 @@ class AxolotlTrainer(Trainer):
             if use_cosine_min_lr:
                 LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
 
+        if self.args.jagged_restart_steps:
+            warmup_steps = (
+                self.args.jagged_restarts_warmup_steps or 10
+            )
+            anneal_steps = (
+                self.args.jagged_restarts_anneal_steps or 1
+            )
+            self.lr_scheduler = JaggedLRRestartScheduler(
+                optimizer,
+                self.lr_scheduler,
+                self.args.jagged_restart_steps,
+                warmup_steps,
+                anneal_steps,
+            )
+
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
@@ -873,6 +900,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
) )
if self.cfg.save_only_model:
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
training_arguments_kwargs["lr_scheduler_type"] = ( training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler self.cfg.lr_scheduler
if self.cfg.lr_scheduler if self.cfg.lr_scheduler
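
A hypothetical construction pulling the new arguments together. This is a sketch, not axolotl's own usage: the import path follows the repo layout, output_dir is required by the stock TrainingArguments, and the numeric values are illustrative rather than defaults.

from axolotl.core.trainer_builder import AxolotlTrainingArguments

args = AxolotlTrainingArguments(
    output_dir="./out",
    jagged_restart_steps=200,         # restart the LR cycle every 200 steps
    jagged_restarts_warmup_steps=20,  # falls back to 10 if unset
    jagged_restarts_anneal_steps=5,   # falls back to 1 if unset
    save_only_model=True,             # skip optimizer/scheduler state in checkpoints
)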

src/axolotl/prompt_strategies/chat_template.py (new file)

@@ -0,0 +1,67 @@
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class ChatTemplatePrompter(Prompter):
+    """Prompter that renders a conversation with the tokenizer's chat template."""
+
+    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+        self.tokenizer = tokenizer
+        self.chat_template = chat_template
+        self.max_length = max_length
+
+    def build_prompt(self, conversation, add_generation_prompt=False):
+        return self.tokenizer.apply_chat_template(
+            conversation,
+            truncation=True,
+            max_length=self.max_length,
+            add_generation_prompt=add_generation_prompt,
+            chat_template=self.chat_template,
+        )
+
+
+class ChatTemplateStrategy(PromptTokenizingStrategy):
+    """Tokenizing strategy for prompts rendered via a chat template."""
+
+    def tokenize_prompt(self, prompt):
+        turns = self.get_conversation_thread(prompt)
+        # tokenize the first turn alone, with the generation prompt appended,
+        # so the user portion can be masked out of the labels below
+        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
+        input_ids = self.prompter.build_prompt(turns)
+
+        if not self.train_on_inputs:
+            user_prompt_len = len(prompt_ids)
+            labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
+        else:
+            labels = input_ids
+
+        tokenized_prompt = {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": [1] * len(input_ids),
+        }
+
+        return tokenized_prompt
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap roles - allow for assistant turn
+        role_map = {
+            "human": "user",
+            "user": "user",
+            "assistant": "assistant",
+            "gpt": "assistant",
+        }
+        turns = [
+            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
+        ]
+        return turns
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    strategy = ChatTemplateStrategy(
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(ds_cfg["conversation"]),
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
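
A sketch of driving the new strategy directly, outside the config pipeline. The import path is assumed from axolotl's prompt_strategies layout, SimpleNamespace stands in for the real config object, and the model id and "chatml" template name are illustrative assumptions.

from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies.chat_template import load  # assumed path

tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)
strategy = load(tokenizer, cfg, ds_cfg={"conversation": "chatml"})

sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "4"},
    ]
}
tokenized = strategy.tokenize_prompt(sample)
# with train_on_inputs=False, label positions covering the user turn are -100,
# so loss is computed only on the assistant reply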

src/axolotl/utils/callbacks.py

@@ -62,7 +62,7 @@ class EvalFirstStepCallback(
     ):
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and args.eval_steps < 1.0
+            and (args.eval_steps < 1.0 or args.eval_steps > 1)
             and state.global_step == 1
         ):
             control.should_evaluate = True
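
The widened condition covers both ways transformers interprets eval_steps: a fraction of total steps (0 < x < 1) or an absolute step interval. A standalone restatement of the predicate, with illustrative values; eval_steps == 1 is deliberately excluded, since that interval already evaluates on step 1 without the callback's help.

def should_force_first_eval(eval_steps) -> bool:
    # mirrors the condition above, minus the strategy and global_step checks
    return eval_steps < 1.0 or eval_steps > 1

assert should_force_first_eval(0.25)   # ratio-based eval interval
assert should_force_first_eval(500)    # absolute eval interval
assert not should_force_first_eval(1)  # would evaluate at step 1 anyway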

src/axolotl/utils/schedulers.py

@@ -1,6 +1,7 @@
"""Module for custom LRScheduler class""" """Module for custom LRScheduler class"""
import math import math
from functools import partial from functools import partial
from typing import Sequence
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -140,3 +141,48 @@ def get_cosine_schedule_with_min_lr(
         min_lr_ratio=min_lr_ratio,
     )
 
     return LambdaLR(optimizer, lr_lambda)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-restart learning rate warmup and annealing."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restarts_steps: int,
+        jagged_restarts_warmup_steps: int,
+        jagged_restarts_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restarts_steps
+        self.warmup_steps = jagged_restarts_warmup_steps
+        self.anneal_steps = jagged_restarts_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> Sequence[float]:
+        # keep the wrapped schedule in lockstep, then scale its LRs
+        self.inner_schedule.last_epoch = self.last_epoch
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+        if step < self.restarts_steps:
+            # first cycle runs the inner schedule unmodified
+            scale = 1.0
+        else:
+            per_relora_progress = step % self.restarts_steps
+            if per_relora_progress < self.warmup_steps:
+                # linear warmup after each restart
+                cycle_t = min(1.0, per_relora_progress / self.warmup_steps)
+            elif per_relora_progress > (self.restarts_steps - self.anneal_steps):
+                # linear anneal leading into the next restart
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_relora_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1.0
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+        return original * scale
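
A minimal sketch of the resulting LR shape, assuming this branch is installed so the class imports from axolotl.utils.schedulers; the optimizer, constant inner schedule, and cycle constants are illustrative. By default the scaled LR bottoms out at min_lr_scale (0.001) of the inner schedule's value.

import torch

from axolotl.utils.schedulers import JaggedLRRestartScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)
# a constant inner schedule keeps the jagged envelope easy to see
inner = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

scheduler = JaggedLRRestartScheduler(
    optimizer,
    inner,
    jagged_restarts_steps=100,        # restart the cycle every 100 steps
    jagged_restarts_warmup_steps=10,  # 10-step linear warmup after each restart
    jagged_restarts_anneal_steps=5,   # 5-step linear anneal before each restart
)

lrs = []
for _ in range(300):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
# lrs holds the full LR for the first 100 steps, then a warmup/hold/anneal
# sawtooth scaled between min_lr_scale and the inner schedule's LR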