fix: duplicate optim setting
This commit is contained in:
@@ -492,10 +492,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
optim_args = self.cfg.optim_args
|
optim_args = self.cfg.optim_args
|
||||||
training_args_kwargs["optim_args"] = optim_args
|
training_args_kwargs["optim_args"] = optim_args
|
||||||
|
|
||||||
if self.cfg.optimizer == "adamw_anyprecision":
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
if Path(self.cfg.torchdistx_path).exists():
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
sys.path.append(self.cfg.torchdistx_path)
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
if self.cfg.optim_target_modules:
|
if self.cfg.optim_target_modules:
|
||||||
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
||||||
@@ -706,21 +706,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
|
|
||||||
training_arguments_kwargs["optim"] = (
|
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
|
||||||
)
|
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optim_args = ",".join(
|
|
||||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optim_args = self.cfg.optim_args
|
|
||||||
training_arguments_kwargs["optim_args"] = optim_args
|
|
||||||
if self.cfg.optim_target_modules:
|
|
||||||
training_arguments_kwargs["optim_target_modules"] = (
|
|
||||||
self.cfg.optim_target_modules
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
@@ -1082,7 +1067,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
optim=self.cfg.optimizer,
|
|
||||||
**training_args_kwargs,
|
**training_args_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user