diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py index fc2b0e59d..23e4f6c25 100644 --- a/src/axolotl/core/trainers/mixins/scheduler.py +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -25,7 +25,7 @@ class SchedulerMixin(Trainer): args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None ) -> LRScheduler: """ Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or @@ -45,6 +45,13 @@ class SchedulerMixin(Trainer): and self.args.cosine_min_lr_ratio is not None ) + if optimizer is None: + if self.optimizer is None: + raise ValueError( + "Optimizer must be set before calling create_scheduler or passed as an argument." + ) + optimizer = self.optimizer + # fmt: off if self.lr_scheduler is None: # type: ignore # fmt: on