update create_optimizer for updated api

This commit is contained in:
Wing Lian
2026-02-19 23:49:32 -05:00
parent eb59070040
commit 3b5a9d1d88

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self):
def create_optimizer(self, model=None):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
opt_model = self.model if model is None else model
if (
not self.optimizer