diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d79801a6a..13bac93b7 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -492,10 +492,10 @@ class TrainerBuilderBase(abc.ABC):
                 optim_args = self.cfg.optim_args
             training_args_kwargs["optim_args"] = optim_args
 
-            if self.cfg.optimizer == "adamw_anyprecision":
-                if Path(self.cfg.torchdistx_path).exists():
-                    sys.path.append(self.cfg.torchdistx_path)
-                    importlib.import_module("torchdistx")
+        if self.cfg.optimizer == "adamw_anyprecision":
+            if Path(self.cfg.torchdistx_path).exists():
+                sys.path.append(self.cfg.torchdistx_path)
+                importlib.import_module("torchdistx")
 
         if self.cfg.optim_target_modules:
             training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
@@ -706,21 +706,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
 
-        training_arguments_kwargs["optim"] = (
-            self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
-        )
-        if self.cfg.optim_args:
-            if isinstance(self.cfg.optim_args, dict):
-                optim_args = ",".join(
-                    [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
-                )
-            else:
-                optim_args = self.cfg.optim_args
-            training_arguments_kwargs["optim_args"] = optim_args
-        if self.cfg.optim_target_modules:
-            training_arguments_kwargs["optim_target_modules"] = (
-                self.cfg.optim_target_modules
-            )
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
         training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
@@ -1082,7 +1067,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             logging_first_step=True,
-            optim=self.cfg.optimizer,
             **training_args_kwargs,
         )
 