From 8010376db966a31b4328d8eeb5b9f12523070852 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 18:53:21 +0700 Subject: [PATCH] fix: change to pass dict via arg instead of updating dict --- src/axolotl/core/trainer_builder/base.py | 74 +++++++----------------- 1 file changed, 20 insertions(+), 54 deletions(-) diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py index 5a785b4f8..8fbf1efe8 100644 --- a/src/axolotl/core/trainer_builder/base.py +++ b/src/axolotl/core/trainer_builder/base.py @@ -178,9 +178,7 @@ class TrainerBuilderBase(abc.ABC): # TODO return trainer - def _configure_warmup_and_logging(self, total_num_steps): - training_args_kwargs = {} - + def _configure_warmup_and_logging(self, total_num_steps, training_args_kwargs): warmup_steps = 0 warmup_ratio = 0.0 if self.cfg.warmup_steps: @@ -198,25 +196,19 @@ class TrainerBuilderBase(abc.ABC): if warmup_steps == 1: warmup_steps = 2 - logging_steps = ( - self.cfg.logging_steps - if self.cfg.logging_steps is not None - else ( + if self.cfg.logging_steps is not None: + training_args_kwargs["logging_steps"] = self.cfg.logging_steps + else: + training_args_kwargs["logging_steps"] = ( 500 # transformers defaults to 500 if not total_num_steps else max(min(int(0.005 * total_num_steps), 10), 1) ) - ) training_args_kwargs["warmup_ratio"] = warmup_ratio training_args_kwargs["warmup_steps"] = warmup_steps - training_args_kwargs["logging_steps"] = logging_steps - - return training_args_kwargs - - def _configure_precision_settings(self): - training_args_kwargs = {} + def _configure_precision_settings(self, training_args_kwargs): training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False training_args_kwargs["tf32"] = self.cfg.tf32 if self.cfg.bf16 == "full": @@ -224,11 +216,7 @@ class TrainerBuilderBase(abc.ABC): else: training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16 - return training_args_kwargs - - def _configure_optimizer_and_scheduler(self): - training_args_kwargs = {} - + def _configure_optimizer_and_scheduler(self, training_args_kwargs): if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]: training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler @@ -361,11 +349,7 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.optim_target_modules: training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules - return training_args_kwargs - - def _configure_hub_parameters(self): - training_args_kwargs = {} - + def _configure_hub_parameters(self, training_args_kwargs): if self.cfg.hub_model_id: training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id training_args_kwargs["push_to_hub"] = True @@ -375,11 +359,7 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.hub_strategy: training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy - return training_args_kwargs - - def _configure_save_and_eval_strategy(self): - training_args_kwargs = {} - + def _configure_save_and_eval_strategy(self, training_args_kwargs): # save_strategy and save_steps if self.cfg.save_steps: training_args_kwargs["save_strategy"] = "steps" @@ -404,11 +384,7 @@ class TrainerBuilderBase(abc.ABC): elif self.cfg.eval_strategy: training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy - return training_args_kwargs - - def _configure_reporting(self): - training_args_kwargs = {} - + def _configure_reporting(self, training_args_kwargs): report_to = [] if self.cfg.use_wandb: report_to.append("wandb") @@ -428,11 +404,7 @@ class TrainerBuilderBase(abc.ABC): else: training_args_kwargs["run_name"] = None - return training_args_kwargs - - def _configure_torch_compile(self): - training_args_kwargs = {} - + def _configure_torch_compile(self, training_args_kwargs): if self.cfg.torch_compile and getattr(torch, "_dynamo", None): torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access True @@ -445,11 +417,7 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.torch_compile_mode: training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode - return training_args_kwargs - - def _configure_gradient_checkpointing(self): - training_args_kwargs = {} - + def _configure_gradient_checkpointing(self, training_args_kwargs): if self.cfg.gradient_checkpointing: training_args_kwargs["gradient_checkpointing"] = ( self.cfg.gradient_checkpointing @@ -463,18 +431,16 @@ class TrainerBuilderBase(abc.ABC): "use_reentrant": False } - return training_args_kwargs - def _set_base_training_args(self, total_num_steps) -> dict[str, Any]: training_args_kwargs: Dict[str, Any] = {} - training_args_kwargs.update(self._configure_warmup_and_logging(total_num_steps)) + self._configure_warmup_and_logging(total_num_steps, training_args_kwargs) - training_args_kwargs.update(self._configure_precision_settings()) + self._configure_precision_settings(training_args_kwargs) - training_args_kwargs.update(self._configure_save_and_eval_strategy()) + self._configure_save_and_eval_strategy(training_args_kwargs) - training_args_kwargs.update(self._configure_gradient_checkpointing()) + self._configure_gradient_checkpointing(training_args_kwargs) # set arg into trainer_args_kwargs with same name if value not None for arg in [ @@ -521,12 +487,12 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.reward_model or self.cfg.rl: training_args_kwargs["max_length"] = self.cfg.sequence_len - training_args_kwargs.update(self._configure_reporting()) + self._configure_reporting(training_args_kwargs) - training_args_kwargs.update(self._configure_hub_parameters()) + self._configure_hub_parameters(training_args_kwargs) - training_args_kwargs.update(self._configure_optimizer_and_scheduler()) + self._configure_optimizer_and_scheduler(training_args_kwargs) - training_args_kwargs.update(self._configure_torch_compile()) + self._configure_torch_compile(training_args_kwargs) return training_args_kwargs