diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py
index b0a5ee895..0c36f9f95 100644
--- a/src/axolotl/core/trainers/mixins/scheduler.py
+++ b/src/axolotl/core/trainers/mixins/scheduler.py
@@ -3,9 +3,10 @@
 import logging
 
 import torch
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 from transformers.trainer import Trainer
 
+from axolotl.integrations.base import PluginManager
 from axolotl.utils.schedulers import (
     RexLR,
     get_cosine_schedule_with_min_lr,
@@ -25,9 +26,9 @@ class SchedulerMixin(Trainer):
 
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
-    ):
+    ) -> LRScheduler:
         """
-        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
         passed as an argument.
 
         Args:
@@ -47,7 +48,16 @@ class SchedulerMixin(Trainer):
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if self.args.alternate_lr_scheduler_type == "one_cycle":
+            plugin_manager = PluginManager.get_instance()
+            lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
+                trainer=self,
+                optimizer=optimizer,
+                num_training_steps=num_training_steps
+            )
+            if lr_scheduler is not None:
+                LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
+                self.lr_scheduler = lr_scheduler
+            elif self.args.alternate_lr_scheduler_type == "one_cycle":
                 num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
                 pct_start = num_warmup_steps / num_training_steps
                 extra_lr_kwargs = {}
@@ -110,4 +120,4 @@ class SchedulerMixin(Trainer):
             if use_cosine_min_lr:
                 LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
 
-        return self.lr_scheduler
+        return self.lr_scheduler  # type: ignore
diff --git a/src/axolotl/core/trainers/relora.py b/src/axolotl/core/trainers/relora.py
index 3bcd4a9b8..890278f49 100644
--- a/src/axolotl/core/trainers/relora.py
+++ b/src/axolotl/core/trainers/relora.py
@@ -1,6 +1,7 @@
 """Module for ReLoRA trainer"""
 
 import torch
+from torch.optim.lr_scheduler import LRScheduler
 
 from axolotl.core.trainers.base import AxolotlTrainer
 from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -19,9 +20,11 @@ class ReLoRATrainer(AxolotlTrainer):
         self,
         num_training_steps: int,
         optimizer: torch.optim.Optimizer | None = None,
-    ):
+    ) -> LRScheduler:
         optimizer = self.optimizer if optimizer is None else optimizer
-        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
+        lr_scheduler: LRScheduler = super().create_scheduler(
+            num_training_steps, optimizer
+        )
 
         if self.args.relora_steps:
             warmup_steps = (
@@ -30,7 +33,7 @@ class ReLoRATrainer(AxolotlTrainer):
             anneal_steps = (
                 self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
             )
-            self.lr_scheduler = ReLoRAScheduler(
+            self.lr_scheduler = ReLoRAScheduler(  # type: ignore
                 optimizer,
                 lr_scheduler,
                 self.args.relora_steps,
@@ -38,6 +41,6 @@ class ReLoRATrainer(AxolotlTrainer):
                 warmup_steps,
             )
         else:
-            self.lr_scheduler = lr_scheduler
+            self.lr_scheduler = lr_scheduler  # type: ignore
 
-        return self.lr_scheduler
+        return self.lr_scheduler  # type: ignore
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 7d6491478..efe542af7 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -24,6 +24,7 @@
 import logging
 from typing import OrderedDict
 
 import torch
+from torch.optim.lr_scheduler import LRScheduler
 
 class BasePlugin:
@@ -41,7 +42,7 @@ class BasePlugin:
     post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
     post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
     create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
-    create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
+    create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
     add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
     add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
     """
@@ -146,8 +147,8 @@ class BasePlugin:
         """
 
     def create_lr_scheduler(
-        self, cfg, trainer, optimizer
-    ):  # pylint: disable=unused-argument
+        self, cfg, trainer, optimizer, num_training_steps
+    ) -> LRScheduler | None:  # pylint: disable=unused-argument
         """
         Creates and returns a learning rate scheduler.
 
@@ -155,9 +156,10 @@ class BasePlugin:
             cfg (dict): The configuration for the plugin.
             trainer (object): The trainer object for training.
            optimizer (object): The optimizer for training.
+            num_training_steps (int): Total number of training steps
 
         Returns:
-            object: The created learning rate scheduler.
+            object (LRScheduler): The created learning rate scheduler.
         """
 
     def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
@@ -436,7 +438,9 @@ class PluginManager:
                 return optimizer
         return None
 
-    def create_lr_scheduler(self, trainer, optimizer):
+    def create_lr_scheduler(
+        self, trainer, optimizer, num_training_steps
+    ) -> LRScheduler | None:
         """
         Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
 
@@ -448,7 +452,12 @@ class PluginManager:
             object: The created learning rate scheduler, or None if none was found.
         """
         for plugin in self.plugins.values():
-            scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer)
+            scheduler: LRScheduler | None = plugin.create_lr_scheduler(
+                self.cfg,
+                trainer=trainer,
+                optimizer=optimizer,
+                num_training_steps=num_training_steps,
+            )
             if scheduler is not None:
                 return scheduler
         return None
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index e51334dfe..9b12e6d4e 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -72,7 +72,7 @@ class LogHooksPlugin(BasePlugin):
             f.write("get_trainer_cls\n")
 
     def create_lr_scheduler(
-        self, cfg, trainer, optimizer
+        self, cfg, trainer, optimizer, num_training_steps
     ):  # pylint: disable=unused-argument
         with open(
             self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -172,7 +172,7 @@ class TestPluginHooks:
         assert "post_model_load" in file_contents
         # assert "create_optimizer" in file_contents  # not implemented yet
         assert "get_trainer_cls" in file_contents
-        # assert "create_lr_scheduler" in file_contents  # not implemented yet
+        assert "create_lr_scheduler" in file_contents
        assert "add_callbacks_pre_trainer" in file_contents
        assert "add_callbacks_post_trainer" in file_contents
        assert "post_train" in file_contents
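For reference, a minimal sketch of how a downstream plugin might use the new hook. The plugin class name, the 10% warmup fraction, and the LambdaLR schedule below are illustrative assumptions and are not part of this change; the only requirement introduced here is overriding create_lr_scheduler with the extended (cfg, trainer, optimizer, num_training_steps) signature and returning an LRScheduler, or None to fall through to the trainer's built-in scheduler logic.

# Hypothetical example plugin (illustrative only, not part of this diff).
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from axolotl.integrations.base import BasePlugin


class LinearWarmupSchedulerPlugin(BasePlugin):
    """Sketch: supply a linear-warmup, then constant-LR schedule via the plugin hook."""

    def create_lr_scheduler(
        self, cfg, trainer, optimizer, num_training_steps
    ) -> LRScheduler | None:
        # Warm up over the first 10% of training steps (assumed fraction), then hold the base LR.
        warmup_steps = max(1, int(0.1 * num_training_steps))

        def lr_lambda(step: int) -> float:
            return min(1.0, (step + 1) / warmup_steps)

        # Returning a scheduler short-circuits SchedulerMixin.create_scheduler;
        # returning None would let other plugins or the default scheduler logic run.
        return LambdaLR(optimizer, lr_lambda)

When such a plugin is registered (e.g., via the config's plugins list), PluginManager.create_lr_scheduler returns the first non-None scheduler, and SchedulerMixin logs and uses it instead of its one_cycle/REX/cosine-min-lr branches.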