fix: duplicate optim setting

Optimizer settings were applied in more than one place: HFCausalTrainerBuilder
set optim, optim_args, and optim_target_modules a second time on top of
TrainerBuilderBase, and HFRLTrainerBuilder passed optim= explicitly even though
training_args_kwargs already carries it. Remove the duplicates so the base
builder is the single place the optimizer is configured.

Author: NanoCode012
Date:   2025-03-26 15:02:21 +07:00
Parent: 427e612d5a
Commit: 7c91cbddd3

@@ -492,10 +492,10 @@ class TrainerBuilderBase(abc.ABC):
                 optim_args = self.cfg.optim_args
             training_args_kwargs["optim_args"] = optim_args
-            if self.cfg.optimizer == "adamw_anyprecision":
-                if Path(self.cfg.torchdistx_path).exists():
-                    sys.path.append(self.cfg.torchdistx_path)
-                    importlib.import_module("torchdistx")
+        if self.cfg.optimizer == "adamw_anyprecision":
+            if Path(self.cfg.torchdistx_path).exists():
+                sys.path.append(self.cfg.torchdistx_path)
+                importlib.import_module("torchdistx")
         if self.cfg.optim_target_modules:
             training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
@@ -706,21 +706,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
-        training_arguments_kwargs["optim"] = (
-            self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
-        )
-        if self.cfg.optim_args:
-            if isinstance(self.cfg.optim_args, dict):
-                optim_args = ",".join(
-                    [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
-                )
-            else:
-                optim_args = self.cfg.optim_args
-            training_arguments_kwargs["optim_args"] = optim_args
-        if self.cfg.optim_target_modules:
-            training_arguments_kwargs["optim_target_modules"] = (
-                self.cfg.optim_target_modules
-            )
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
         training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
@@ -1082,7 +1067,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             logging_first_step=True,
-            optim=self.cfg.optimizer,
             **training_args_kwargs,
         )
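
Note: if training_args_kwargs also carries an "optim" key, passing optim=
explicitly as well fails at call time, which is the duplicate this hunk
removes. A minimal sketch of that failure mode (build_args is a hypothetical
stand-in, not the axolotl code):

    def build_args(**kwargs):
        return kwargs

    training_args_kwargs = {"optim": "adamw_torch"}
    try:
        build_args(optim="adamw_torch", **training_args_kwargs)
    except TypeError as err:
        print(err)  # build_args() got multiple values for keyword argument 'optim'
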