From edaec9fe988fba4bb0d11129e904fdf5caa967ca Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 21 Feb 2025 13:55:17 +0700 Subject: [PATCH] fix: add missing weight_decay handling --- src/axolotl/core/trainer_builder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4517fe102..abced8ceb 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -318,6 +318,7 @@ class TrainerBuilderBase(abc.ABC): "save_safetensors", "save_only_model", "include_tokens_per_second", + "weight_decay", ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) @@ -601,10 +602,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups - training_arguments_kwargs["weight_decay"] = ( - self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 - ) - training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["multipack_real_batches"] = ( not self.cfg.flash_attention or self.cfg.multipack_real_batches @@ -1004,6 +1001,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: training_args_kwargs["remove_unused_columns"] = False + # only rlhf if self.cfg.dataset_processes: training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes