fix: add missing weight_decay handling
This commit is contained in:
@@ -318,6 +318,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
"save_safetensors",
|
"save_safetensors",
|
||||||
"save_only_model",
|
"save_only_model",
|
||||||
"include_tokens_per_second",
|
"include_tokens_per_second",
|
||||||
|
"weight_decay",
|
||||||
]:
|
]:
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
@@ -601,10 +602,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
|
|
||||||
training_arguments_kwargs["weight_decay"] = (
|
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||||
training_arguments_kwargs["multipack_real_batches"] = (
|
training_arguments_kwargs["multipack_real_batches"] = (
|
||||||
not self.cfg.flash_attention or self.cfg.multipack_real_batches
|
not self.cfg.flash_attention or self.cfg.multipack_real_batches
|
||||||
@@ -1004,6 +1001,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
training_args_kwargs["remove_unused_columns"] = False
|
training_args_kwargs["remove_unused_columns"] = False
|
||||||
|
|
||||||
|
# only rlhf
|
||||||
if self.cfg.dataset_processes:
|
if self.cfg.dataset_processes:
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user