fix: add missing weight_decay handling

Author: NanoCode012
Date: 2025-02-21 13:55:17 +07:00
parent 8b6db0c72d
commit edaec9fe98


@@ -318,6 +318,7 @@ class TrainerBuilderBase(abc.ABC):
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
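
For context, this hunk extends a generic pass-through in the base builder: any config attribute named in the allow-list is copied into the TrainingArguments kwargs only when the user actually set it. Below is a minimal, self-contained sketch of that pattern, not the project's code; the surrounding for-loop shape is assumed from the hunk, and cfg here is a stand-in namespace rather than the real config object.

    # Sketch of the config -> training_args_kwargs pass-through this hunk extends.
    # Only the arg names visible in the diff are taken from it; everything else
    # (SimpleNamespace config, empty kwargs dict) is illustrative.
    from types import SimpleNamespace

    cfg = SimpleNamespace(weight_decay=0.01, save_safetensors=None)  # hypothetical config
    training_args_kwargs = {}

    for arg in [
        "save_safetensors",
        "save_only_model",
        "include_tokens_per_second",
        "weight_decay",  # newly forwarded by this commit
    ]:
        # Copy only values the user explicitly set, so TrainingArguments
        # keeps its own defaults for anything left as None / unset.
        if hasattr(cfg, arg) and getattr(cfg, arg) is not None:
            training_args_kwargs[arg] = getattr(cfg, arg)

    print(training_args_kwargs)  # -> {'weight_decay': 0.01}
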
@@ -601,10 +602,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
not self.cfg.flash_attention or self.cfg.multipack_real_batches
@@ -1004,6 +1001,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         else:
             training_args_kwargs["remove_unused_columns"] = False
         # only rlhf
         if self.cfg.dataset_processes:
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes