bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler() (#3435) [skip ci]
* bug-fix: use self.optimizer if optimizer not passed to SchedulerMixin.create_scheduler() * nit: raise if self.optimizer is also unset * optimizer properly optional in create_scheduler()
This commit is contained in:
@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
|
||||
) -> LRScheduler:
|
||||
"""
|
||||
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
@@ -45,6 +45,13 @@ class SchedulerMixin(Trainer):
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
if optimizer is None:
|
||||
if self.optimizer is None:
|
||||
raise ValueError(
|
||||
"Optimizer must be set before calling create_scheduler or passed as an argument."
|
||||
)
|
||||
optimizer = self.optimizer
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore
|
||||
# fmt: on
|
||||
|
||||
Reference in New Issue
Block a user