update create_optimizer for updated api
This commit is contained in:
@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
|
|||||||
|
|
||||||
return optimizer_grouped_parameters
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self, model=None):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
|
|||||||
and self.args.lr_groups is None
|
and self.args.lr_groups is None
|
||||||
and self.optimizer_cls_and_kwargs is None
|
and self.optimizer_cls_and_kwargs is None
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(model=model)
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model if model is None else model
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self.optimizer
|
not self.optimizer
|
||||||
|
|||||||
Reference in New Issue
Block a user