diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b1c37d8ba..8d625e524 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -49,7 +49,8 @@ from axolotl.utils.collators import (
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
+    JaggedLRRestartScheduler,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
 )
 
 try:
@@ -129,7 +130,19 @@ class AxolotlTrainingArguments(TrainingArguments):
     )
     relora_anneal_steps: Optional[int] = field(
         default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+        metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many steps between jagged restarts"},
+    )
+    jagged_restarts_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after each jagged restart"},
+    )
+    jagged_restarts_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take before each jagged restart"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -226,7 +239,7 @@ class AxolotlTrainer(Trainer):
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer)
+                super().create_scheduler(num_training_steps, optimizer)
         else:
             if use_cosine_quadratic:
                 LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -234,6 +247,19 @@ class AxolotlTrainer(Trainer):
             if use_cosine_min_lr:
                 LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
 
+        # wrap whichever scheduler was built above so the LR dips, warms up,
+        # and anneals around each restart
+        if self.args.jagged_restart_steps:
+            warmup_steps = self.args.jagged_restarts_warmup_steps or 10
+            anneal_steps = self.args.jagged_restarts_anneal_steps or 1
+            self.lr_scheduler = JaggedLRRestartScheduler(
+                optimizer,
+                self.lr_scheduler,
+                self.args.jagged_restart_steps,
+                warmup_steps,
+                anneal_steps,
+            )
+
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
@@ -873,6 +899,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
        )
+        if self.cfg.save_only_model:
+            training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
         training_arguments_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler
             if self.cfg.lr_scheduler
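Note for reviewers (not part of the patch): `create_scheduler` now falls through and wraps whatever base scheduler was built, HF default or otherwise, whenever `jagged_restart_steps` is set. A minimal sketch of that composition, where the `SGD` optimizer and constant `LambdaLR` are stand-ins for the objects the trainer would normally construct:

```python
# Sketch only: compose JaggedLRRestartScheduler with a base schedule the way
# create_scheduler does. The optimizer and base schedule are stand-ins.
import torch
from torch.optim.lr_scheduler import LambdaLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
base = LambdaLR(optimizer, lambda step: 1.0)  # stand-in for the HF-built scheduler

# restart every 100 steps; warmup 10 and anneal 1 mirror the fallbacks above
scheduler = JaggedLRRestartScheduler(optimizer, base, 100, 10, 1)

lrs = []
for _ in range(120):
    optimizer.step()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

# LR holds at 1e-4 through the first cycle, drops to ~1e-7 (the min_lr_scale
# floor) at the restart, then warms back up over the next 10 steps
print(min(lrs), max(lrs))
```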
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
new file mode 100644
index 000000000..610f3607a
--- /dev/null
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -0,0 +1,79 @@
+"""Prompt strategy that tokenizes conversations with the tokenizer's chat template."""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class ChatTemplatePrompter(Prompter):
+    """Prompter that renders conversation turns with a HF chat template."""
+
+    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+        self.tokenizer = tokenizer
+        self.chat_template = chat_template
+        self.max_length = max_length
+
+    def build_prompt(self, conversation, add_generation_prompt=False):
+        # apply_chat_template tokenizes the rendered conversation and
+        # truncates it to max_length
+        return self.tokenizer.apply_chat_template(
+            conversation,
+            truncation=True,
+            max_length=self.max_length,
+            add_generation_prompt=add_generation_prompt,
+            chat_template=self.chat_template,
+        )
+
+
+class ChatTemplateStrategy(PromptTokenizingStrategy):
+    """Tokenizing strategy for chat-template conversations."""
+
+    def tokenize_prompt(self, prompt):
+        turns = self.get_conversation_thread(prompt)
+        # tokenize the first turn with the generation prompt appended so we
+        # know how many leading tokens to mask when not training on inputs
+        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
+        input_ids = self.prompter.build_prompt(turns)
+
+        if not self.train_on_inputs:
+            user_prompt_len = len(prompt_ids)
+            labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
+        else:
+            labels = input_ids
+
+        tokenized_prompt = {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": [1] * len(input_ids),
+        }
+
+        return tokenized_prompt
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap sharegpt-style roles; allow explicit user/assistant roles too
+        role_map = {
+            "human": "user",
+            "user": "user",
+            "assistant": "assistant",
+            "gpt": "assistant",
+        }
+        turns = [
+            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
+        ]
+        return turns
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    strategy = ChatTemplateStrategy(
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(ds_cfg["conversation"]),
+            max_length=cfg.sequence_len,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
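Again not part of the patch: a quick way to sanity-check the new strategy end to end. The model id, the "chatml" template choice, and the sample conversation below are illustrative assumptions; any tokenizer plus any template registered in `axolotl.utils.chat_templates` should behave the same:

```python
# Sketch only: exercise ChatTemplateStrategy on a single sharegpt-style sample.
from transformers import AutoTokenizer

from axolotl.prompt_strategies.chat_template import (
    ChatTemplatePrompter,
    ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")  # assumed model
strategy = ChatTemplateStrategy(
    ChatTemplatePrompter(tokenizer, chat_templates("chatml"), max_length=2048),
    tokenizer,
    False,  # train_on_inputs: mask the user turn in the labels
    2048,  # sequence_len
)

sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "2 + 2 = 4."},
    ]
}
out = strategy.tokenize_prompt(sample)
# the user turn (plus the generation prompt) is masked with -100 in the labels,
# so the loss is computed only on the assistant reply
print(out["labels"][:10])
```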
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 9de266be1..25ccf17a5 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -62,7 +62,7 @@ class EvalFirstStepCallback(
     ):
         if (
             args.evaluation_strategy == IntervalStrategy.STEPS
-            and args.eval_steps < 1.0
+            and (args.eval_steps < 1.0 or args.eval_steps > 1)
             and state.global_step == 1
         ):
             control.should_evaluate = True
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index c49745c26..7ad754793 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -1,6 +1,7 @@
 """Module for custom LRScheduler class"""
 import math
 from functools import partial
+from typing import Sequence, Union
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -140,3 +141,53 @@ def get_cosine_schedule_with_min_lr(
         min_lr_ratio=min_lr_ratio,
     )
     return LambdaLR(optimizer, lr_lambda)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-restart LR warmup and anneal."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restarts_steps: int,
+        jagged_restarts_warmup_steps: int,
+        jagged_restarts_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restarts_steps
+        self.warmup_steps = jagged_restarts_warmup_steps
+        self.anneal_steps = jagged_restarts_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> Union[float, Sequence[float]]:
+        # keep the wrapped scheduler in lockstep with this one
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps:
+            # first cycle: run the inner schedule unmodified
+            scale = 1.0
+        else:
+            per_restart_progress = step % self.restarts_steps
+            if per_restart_progress < self.warmup_steps:
+                # ramp back up from the floor after each restart
+                cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
+            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
+                # ramp down just before the next restart
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1.0
+            # scale the inner LR, flooring at min_lr_scale instead of zero
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+        return original * scale
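To make the "jagged" shape concrete, here is the scale arithmetic from `get_lr` worked through standalone. The `restart_steps=100` and `warmup_steps=10` values are illustrative; `anneal_steps=1` and `min_lr_scale=0.001` match the defaults in the patch:

```python
# Standalone check of the scale formula used in JaggedLRRestartScheduler.get_lr.
import math

restart_steps, warmup_steps, anneal_steps, min_lr_scale = 100, 10, 1, 0.001

def scale(step: int) -> float:
    if step < restart_steps:
        return 1.0  # first cycle: inner schedule runs unmodified
    progress = step % restart_steps
    if progress < warmup_steps:
        cycle_t = min(1.0, progress / warmup_steps)  # ramp up after a reset
    elif progress > restart_steps - anneal_steps:
        cycle_t = min(1.0, (restart_steps - progress) / anneal_steps)  # ramp down
    else:
        cycle_t = 1.0  # plateau between warmup and anneal
    return cycle_t * (1 - min_lr_scale) + min_lr_scale  # floor at min_lr_scale

assert scale(99) == 1.0  # still inside the first cycle
assert scale(100) == min_lr_scale  # reset: scale drops to the floor
assert math.isclose(scale(105), 0.5 * (1 - min_lr_scale) + min_lr_scale)  # mid-warmup
assert math.isclose(scale(150), 1.0)  # plateau
```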