Add plugin support for create_lr_scheduler (#2584)

* lr_scheduler support

* fix

* Update scheduler.py

* Update scheduler.py

* cfg handling

* black

* remove debug

* remove adding the axolotl cfg to the scheduler mixin

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Aleksandr Dremov
2025-04-29 23:08:30 +02:00
committed by GitHub
parent ecac731922
commit 41a1ec0c95
4 changed files with 40 additions and 18 deletions

View File

@@ -3,9 +3,10 @@
import logging import logging
import torch import torch
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
RexLR, RexLR,
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
@@ -25,9 +26,9 @@ class SchedulerMixin(Trainer):
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
): ) -> LRScheduler:
""" """
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument. passed as an argument.
Args: Args:
@@ -47,7 +48,16 @@ class SchedulerMixin(Trainer):
# fmt: off # fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on # fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle": plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
trainer=self,
optimizer=optimizer,
num_training_steps=num_training_steps
)
if lr_scheduler is not None:
LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
self.lr_scheduler = lr_scheduler
elif self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps) num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {} extra_lr_kwargs = {}
@@ -110,4 +120,4 @@ class SchedulerMixin(Trainer):
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler return self.lr_scheduler # type: ignore

View File

@@ -1,6 +1,7 @@
"""Module for ReLoRA trainer""" """Module for ReLoRA trainer"""
import torch import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -19,9 +20,11 @@ class ReLoRATrainer(AxolotlTrainer):
self, self,
num_training_steps: int, num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None, optimizer: torch.optim.Optimizer | None = None,
): ) -> LRScheduler:
optimizer = self.optimizer if optimizer is None else optimizer optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer) lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
if self.args.relora_steps: if self.args.relora_steps:
warmup_steps = ( warmup_steps = (
@@ -30,7 +33,7 @@ class ReLoRATrainer(AxolotlTrainer):
anneal_steps = ( anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
) )
self.lr_scheduler = ReLoRAScheduler( self.lr_scheduler = ReLoRAScheduler( # type: ignore
optimizer, optimizer,
lr_scheduler, lr_scheduler,
self.args.relora_steps, self.args.relora_steps,
@@ -38,6 +41,6 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps, warmup_steps,
) )
else: else:
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler # type: ignore
return self.lr_scheduler return self.lr_scheduler # type: ignore

View File

@@ -24,6 +24,7 @@ import logging
from typing import OrderedDict from typing import OrderedDict
import torch import torch
from torch.optim.lr_scheduler import LRScheduler
class BasePlugin: class BasePlugin:
@@ -41,7 +42,7 @@ class BasePlugin:
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training. create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler. create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training. add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
""" """
@@ -146,8 +147,8 @@ class BasePlugin:
""" """
def create_lr_scheduler( def create_lr_scheduler(
self, cfg, trainer, optimizer self, cfg, trainer, optimizer, num_training_steps
): # pylint: disable=unused-argument ) -> LRScheduler | None: # pylint: disable=unused-argument
""" """
Creates and returns a learning rate scheduler. Creates and returns a learning rate scheduler.
@@ -155,9 +156,10 @@ class BasePlugin:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
optimizer (object): The optimizer for training. optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns: Returns:
object: The created learning rate scheduler. object (LRScheduler): The created learning rate scheduler.
""" """
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
@@ -436,7 +438,9 @@ class PluginManager:
return optimizer return optimizer
return None return None
def create_lr_scheduler(self, trainer, optimizer): def create_lr_scheduler(
self, trainer, optimizer, num_training_steps
) -> LRScheduler | None:
""" """
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
@@ -448,7 +452,12 @@ class PluginManager:
object: The created learning rate scheduler, or None if none was found. object: The created learning rate scheduler, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer) scheduler: LRScheduler | None = plugin.create_lr_scheduler(
self.cfg,
trainer=trainer,
optimizer=optimizer,
num_training_steps=num_training_steps,
)
if scheduler is not None: if scheduler is not None:
return scheduler return scheduler
return None return None

View File

@@ -72,7 +72,7 @@ class LogHooksPlugin(BasePlugin):
f.write("get_trainer_cls\n") f.write("get_trainer_cls\n")
def create_lr_scheduler( def create_lr_scheduler(
self, cfg, trainer, optimizer self, cfg, trainer, optimizer, num_training_steps
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
with open( with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -172,7 +172,7 @@ class TestPluginHooks:
assert "post_model_load" in file_contents assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet # assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents assert "get_trainer_cls" in file_contents
# assert "create_lr_scheduler" in file_contents # not implemented yet assert "create_lr_scheduler" in file_contents
assert "add_callbacks_pre_trainer" in file_contents assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents assert "post_train" in file_contents