bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler() (#3435) [skip ci]

* bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler()

* nit: raise if self.optimizer is also unset

* optimizer properly optional in create_scheduler()
This commit is contained in:
kallewoof
2026-03-02 17:30:07 +09:00
committed by GitHub
parent 18f26c19ef
commit 7f23b302d1

View File

@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
) -> LRScheduler: ) -> LRScheduler:
""" """
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,6 +45,13 @@ class SchedulerMixin(Trainer):
and self.args.cosine_min_lr_ratio is not None and self.args.cosine_min_lr_ratio is not None
) )
if optimizer is None:
if self.optimizer is None:
raise ValueError(
"Optimizer must be set before calling create_scheduler or passed as an argument."
)
optimizer = self.optimizer
# fmt: off # fmt: off
if self.lr_scheduler is None: # type: ignore if self.lr_scheduler is None: # type: ignore
# fmt: on # fmt: on