Plugins create_lr_scheduler support (#2584)

* lr_scheduler support

* fix

* Update scheduler.py

* Update scheduler.py

* cfg handling

* black

* remove debug

* remove adding the axolotl cfg to the scheduler mixin

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Aleksandr Dremov
2025-04-29 23:08:30 +02:00
committed by GitHub
parent ecac731922
commit 41a1ec0c95
4 changed files with 40 additions and 18 deletions
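Before the per-file diffs, a quick sketch of what the new hook enables: an integration plugin can now hand the trainer a fully constructed learning rate scheduler. The example below follows the create_lr_scheduler(cfg, trainer, optimizer, num_training_steps) signature introduced in this commit; the plugin class name and the warmup-only schedule are illustrative assumptions, not code from this PR.

from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from axolotl.integrations.base import BasePlugin


class WarmupOnlyPlugin(BasePlugin):
    """Hypothetical plugin that supplies a linear-warmup-then-constant schedule."""

    def create_lr_scheduler(
        self, cfg, trainer, optimizer, num_training_steps
    ) -> LRScheduler | None:
        # Assumed 3% warmup; returning any LRScheduler hands it to the trainer,
        # while returning None declines and leaves scheduling to other plugins
        # or to the trainer's built-in logic.
        warmup_steps = max(1, int(0.03 * num_training_steps))

        def lr_lambda(step: int) -> float:
            # Ramp linearly up to the base lr, then hold it constant.
            return min(1.0, step / warmup_steps)

        return LambdaLR(optimizer, lr_lambda)

As the scheduler.py hunks below show, a non-None return value takes precedence over the existing alternate_lr_scheduler_type handling.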

View File

@@ -3,9 +3,10 @@
 import logging
 
 import torch
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 from transformers.trainer import Trainer
 
+from axolotl.integrations.base import PluginManager
 from axolotl.utils.schedulers import (
     RexLR,
     get_cosine_schedule_with_min_lr,
@@ -25,9 +26,9 @@ class SchedulerMixin(Trainer):
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
-    ):
+    ) -> LRScheduler:
         """
-        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
         passed as an argument.
 
         Args:
@@ -47,7 +48,16 @@ class SchedulerMixin(Trainer):
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if self.args.alternate_lr_scheduler_type == "one_cycle":
+            plugin_manager = PluginManager.get_instance()
+            lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
+                trainer=self,
+                optimizer=optimizer,
+                num_training_steps=num_training_steps
+            )
+            if lr_scheduler is not None:
+                LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
+                self.lr_scheduler = lr_scheduler
+            elif self.args.alternate_lr_scheduler_type == "one_cycle":
                 num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
                 pct_start = num_warmup_steps / num_training_steps
                 extra_lr_kwargs = {}
@@ -110,4 +120,4 @@ class SchedulerMixin(Trainer):
             if use_cosine_min_lr:
                 LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
 
-        return self.lr_scheduler
+        return self.lr_scheduler  # type: ignore

View File

@@ -1,6 +1,7 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -19,9 +20,11 @@ class ReLoRATrainer(AxolotlTrainer):
         self,
         num_training_steps: int,
         optimizer: torch.optim.Optimizer | None = None,
-    ):
+    ) -> LRScheduler:
         optimizer = self.optimizer if optimizer is None else optimizer
-        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
+        lr_scheduler: LRScheduler = super().create_scheduler(
+            num_training_steps, optimizer
+        )
 
         if self.args.relora_steps:
             warmup_steps = (
@@ -30,7 +33,7 @@ class ReLoRATrainer(AxolotlTrainer):
             anneal_steps = (
                 self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
             )
-            self.lr_scheduler = ReLoRAScheduler(
+            self.lr_scheduler = ReLoRAScheduler(  # type: ignore
                 optimizer,
                 lr_scheduler,
                 self.args.relora_steps,
@@ -38,6 +41,6 @@ class ReLoRATrainer(AxolotlTrainer):
                 warmup_steps,
             )
         else:
-            self.lr_scheduler = lr_scheduler
+            self.lr_scheduler = lr_scheduler  # type: ignore
 
-        return self.lr_scheduler
+        return self.lr_scheduler  # type: ignore

View File

@@ -24,6 +24,7 @@ import logging
 from typing import OrderedDict
 
 import torch
+from torch.optim.lr_scheduler import LRScheduler
 
 
 class BasePlugin:
@@ -41,7 +42,7 @@ class BasePlugin:
     post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
     post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
     create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
-    create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
+    create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
     add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
     add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
     """
@@ -146,8 +147,8 @@ class BasePlugin:
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -155,9 +156,10 @@ class BasePlugin:
             cfg (dict): The configuration for the plugin.
             trainer (object): The trainer object for training.
             optimizer (object): The optimizer for training.
+            num_training_steps (int): Total number of training steps
 
         Returns:
-            object: The created learning rate scheduler.
+            object (LRScheduler): The created learning rate scheduler.
         """
 
     def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
@@ -436,7 +438,9 @@ class PluginManager:
                 return optimizer
         return None
 
-    def create_lr_scheduler(self, trainer, optimizer):
+    def create_lr_scheduler(
+        self, trainer, optimizer, num_training_steps
+    ) -> LRScheduler | None:
         """
         Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
@@ -448,7 +452,12 @@ class PluginManager:
             object: The created learning rate scheduler, or None if none was found.
         """
         for plugin in self.plugins.values():
-            scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer)
+            scheduler: LRScheduler | None = plugin.create_lr_scheduler(
+                self.cfg,
+                trainer=trainer,
+                optimizer=optimizer,
+                num_training_steps=num_training_steps,
+            )
             if scheduler is not None:
                 return scheduler
         return None
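Worth noting from the hunks above: the BasePlugin method body is only a docstring, so a plugin that does not override create_lr_scheduler implicitly returns None and PluginManager's loop simply moves on; the first plugin to return a scheduler wins. A tiny check of that contract, assuming BasePlugin remains directly subclassable and instantiable (the class name here is made up):

from axolotl.integrations.base import BasePlugin


class NoOpPlugin(BasePlugin):
    """Hypothetical plugin that leaves create_lr_scheduler unimplemented."""


# The inherited hook has no return statement, so it yields None and the
# plugin manager keeps looking for a scheduler from other registered plugins.
assert (
    NoOpPlugin().create_lr_scheduler(
        cfg={}, trainer=None, optimizer=None, num_training_steps=10
    )
    is None
)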

View File

@@ -72,7 +72,7 @@ class LogHooksPlugin(BasePlugin):
f.write("get_trainer_cls\n")
def create_lr_scheduler(
self, cfg, trainer, optimizer
self, cfg, trainer, optimizer, num_training_steps
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -172,7 +172,7 @@ class TestPluginHooks:
assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents
# assert "create_lr_scheduler" in file_contents # not implemented yet
assert "create_lr_scheduler" in file_contents
assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents