Compare commits: 4bit-optim...20240216-u (1 commit)

Commit: d465b9fd98
```diff
@@ -49,7 +49,7 @@ from axolotl.utils.collators import (
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
-    get_cosine_schedule_with_quadratic_warmup,
+    get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler,
 )

 try:
```
```diff
@@ -129,7 +129,19 @@ class AxolotlTrainingArguments(TrainingArguments):
     )
     relora_anneal_steps: Optional[int] = field(
         default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+        metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restarts_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
+    )
+    jagged_restarts_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
```
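Taken together, these arguments let a run opt in to jagged restarts by setting only `jagged_restart_steps`; the warmup and anneal windows fall back to 10 and 1 steps in the scheduler wiring further down. A minimal sketch of populating them (the import path and the `output_dir` value are assumptions, not part of this diff):

```python
from axolotl.core.trainer_builder import AxolotlTrainingArguments  # assumed module path

# Hypothetical values; only jagged_restart_steps is required to activate the wrapper.
args = AxolotlTrainingArguments(
    output_dir="./out",                # placeholder required by TrainingArguments
    jagged_restart_steps=200,          # reset the LR cycle every 200 steps
    jagged_restarts_warmup_steps=10,   # re-warm for 10 steps after each reset
    jagged_restarts_anneal_steps=5,    # anneal for 5 steps before each reset
)
```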
```diff
@@ -226,7 +238,7 @@ class AxolotlTrainer(Trainer):
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer)
+                super().create_scheduler(num_training_steps, optimizer)
         else:
             if use_cosine_quadratic:
                 LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
```
```diff
@@ -234,6 +246,21 @@ class AxolotlTrainer(Trainer):
             if use_cosine_min_lr:
                 LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

+        if self.args.jagged_restart_steps:
+            warmup_steps = (
+                self.args.jagged_restarts_warmup_steps or 10
+            )
+            anneal_steps = (
+                self.args.jagged_restarts_anneal_steps or 1
+            )
+            self.lr_scheduler = JaggedLRRestartScheduler(
+                optimizer,
+                self.lr_scheduler,
+                self.args.jagged_restart_steps,
+                warmup_steps,
+                anneal_steps,
+            )
+
         return self.lr_scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
```
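Dropping the `return` in the `else` branch of the previous hunk is what makes this block reachable: HF's `Trainer.create_scheduler` assigns its result to `self.lr_scheduler` in addition to returning it, so every branch now falls through to the jagged check, which wraps whichever scheduler was created, and the single trailing `return self.lr_scheduler` serves all paths.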
```diff
@@ -873,6 +900,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
         )
+        if self.cfg.save_only_model:
+            training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
         training_arguments_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler
             if self.cfg.lr_scheduler
```
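The `save_only_model` passthrough exposes HF's `TrainingArguments.save_only_model` flag (available in recent transformers releases), which skips writing optimizer, scheduler, and RNG state into checkpoints, trading resumability for much smaller checkpoint directories.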
src/axolotl/prompt_strategies/chat_template.py (new file, 67 lines)

```diff
@@ -0,0 +1,67 @@
+from typing import Optional, Dict, Any
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class ChatTemplatePrompter(Prompter):
+    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+        self.tokenizer = tokenizer
+        self.chat_template = chat_template
+        self.max_length = max_length
+
+    def build_prompt(self, conversation, add_generation_prompt=False):
+        return self.tokenizer.apply_chat_template(
+            conversation, truncation=True, max_length=self.max_length,
+            add_generation_prompt=add_generation_prompt,
+            chat_template=self.chat_template,
+        )
+
+
+class ChatTemplateStrategy(PromptTokenizingStrategy):
+    """
+    Tokenizing strategy for instruction-based prompts.
+    """
+
+    def tokenize_prompt(self, prompt):
+        turns = self.get_conversation_thread(prompt)
+        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
+        input_ids = self.prompter.build_prompt(turns)
+
+        if not self.train_on_inputs:
+            user_prompt_len = len(prompt_ids)
+            labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
+        else:
+            labels = input_ids
+
+        tokenized_prompt = {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": [1] * len(input_ids)
+        }
+
+        return tokenized_prompt
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap roles - allow for assistant turn
+        role_map = {"human": "user", "user": "user", "assistant": "assistant", "gpt": "assistant"}
+        turns = [
+            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
+        ]
+        return turns
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    strategy = ChatTemplateStrategy(
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(ds_cfg["conversation"]),
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
```
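A rough usage sketch for the new strategy, assuming a tokenizer that defines a chat template and that `"chatml"` is a valid key for `axolotl.utils.chat_templates.chat_templates`; the model id and sample are illustrative only:

```python
from transformers import AutoTokenizer

from axolotl.prompt_strategies.chat_template import load
from axolotl.utils.dict import DictDefault  # assumed: axolotl's attribute-access config dict

tokenizer = AutoTokenizer.from_pretrained("some/chat-model")  # placeholder model id
cfg = DictDefault({"train_on_inputs": False, "sequence_len": 2048})
strategy = load(tokenizer, cfg, ds_cfg={"conversation": "chatml"})

sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "4."},
    ]
}
out = strategy.tokenize_prompt(sample)
# out["labels"] masks the first turn (rendered with add_generation_prompt=True)
# with -100, so only the assistant reply contributes to the loss.
```

Note that `tokenize_prompt` masks only the first turn's tokens; in a multi-turn conversation, later user turns remain unmasked in the labels.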
```diff
@@ -62,7 +62,7 @@ class EvalFirstStepCallback(
     ):
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and args.eval_steps < 1.0
+            and (args.eval_steps < 1.0 or args.eval_steps > 1)
            and state.global_step == 1
         ):
             control.should_evaluate = True
```
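The widened condition keeps the original fractional case (`eval_steps < 1.0`, i.e. a ratio of total steps) and now also triggers the first-step evaluation for integer step counts greater than 1; `eval_steps == 1` stays excluded since the regular interval already evaluates at step 1.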
```diff
@@ -1,6 +1,7 @@
 """Module for custom LRScheduler class"""
 import math
 from functools import partial
+from typing import Sequence

 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
```
```diff
@@ -140,3 +141,48 @@ def get_cosine_schedule_with_min_lr(
         min_lr_ratio=min_lr_ratio,
     )
     return LambdaLR(optimizer, lr_lambda)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restarts_steps: int,
+        jagged_restarts_warmup_steps: int,
+        jagged_restarts_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restarts_steps
+        self.warmup_steps = jagged_restarts_warmup_steps
+        self.anneal_steps = jagged_restarts_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> float:
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps:
+            scale = 1
+        else:
+            per_relora_progress = step % self.restarts_steps
+            if per_relora_progress < self.warmup_steps:
+                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
+            elif per_relora_progress > (self.restarts_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_relora_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+        return original * scale
```
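To make the resulting LR shape concrete, here is a small standalone re-implementation of the scale factor computed in `get_lr` above (same arithmetic, no torch dependency), evaluated with hypothetical settings of 100 restart steps, 10 warmup steps, and 5 anneal steps:

```python
def jagged_scale(step, restarts_steps=100, warmup_steps=10, anneal_steps=5,
                 min_lr_scale=0.001):
    """Mirror of the scale computation in JaggedLRRestartScheduler.get_lr."""
    if step < restarts_steps:
        return 1.0  # first cycle: inner schedule runs unscaled
    progress = step % restarts_steps
    if progress < warmup_steps:
        cycle_t = min(1.0, progress / warmup_steps)  # re-warm after a reset
    elif progress > restarts_steps - anneal_steps:
        # anneal toward the next reset
        cycle_t = min(1.0, (restarts_steps - progress) / anneal_steps)
    else:
        cycle_t = 1.0  # flat plateau mid-cycle
    return cycle_t * (1 - min_lr_scale) + min_lr_scale  # floored at min_lr_scale


for step in (0, 99, 100, 105, 150, 196, 199, 200):
    print(step, round(jagged_scale(step), 3))
# 0, 99 -> 1.0 (first cycle untouched); 100 -> 0.001 (reset);
# 105 -> ~0.5 (mid-warmup); 150 -> 1.0 (plateau);
# 196 -> ~0.8, 199 -> ~0.2 (anneal); 200 -> 0.001 (next reset)
```

The wrapper multiplies this factor into whatever the inner schedule would have produced, so cosine decay and jagged restarts compose rather than replace each other.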