update create_optimizer for updated api

This commit is contained in:
Wing Lian
2026-02-19 23:49:32 -05:00
parent eb59070040
commit 3b5a9d1d88

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters return optimizer_grouped_parameters
def create_optimizer(self): def create_optimizer(self, model=None):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None and self.optimizer_cls_and_kwargs is None
): ):
return super().create_optimizer() return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model if model is None else model
if ( if (
not self.optimizer not self.optimizer