From 79472241e8df0cf2d4fa43b8fec186e24d9ddc8e Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 22 May 2025 18:39:33 +0700
Subject: [PATCH] chore: refactor set_base_training_args into smaller helper methods

---
 src/axolotl/core/trainer_builder/base.py | 276 ++++++++++++++---------
 1 file changed, 164 insertions(+), 112 deletions(-)

diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py
index bd6dd0cc1..cf78b7186 100644
--- a/src/axolotl/core/trainer_builder/base.py
+++ b/src/axolotl/core/trainer_builder/base.py
@@ -178,8 +178,8 @@ class TrainerBuilderBase(abc.ABC):
 
         # TODO return trainer
 
-    def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
-        training_args_kwargs: Dict[str, Any] = {}
+    def _configure_warmup_and_logging(self, total_num_steps):
+        training_args_kwargs = {}
 
         warmup_steps = 0
         warmup_ratio = 0.0
@@ -212,7 +212,11 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["warmup_steps"] = warmup_steps
         training_args_kwargs["logging_steps"] = logging_steps
 
-        # precision
+        return training_args_kwargs
+
+    def _configure_precision_settings(self):
+        training_args_kwargs = {}
+
         training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
         training_args_kwargs["tf32"] = self.cfg.tf32
         if self.cfg.bf16 == "full":
@@ -220,116 +224,11 @@ class TrainerBuilderBase(abc.ABC):
             training_args_kwargs["bf16_full_eval"] = True
         else:
             training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
-        # hub
-        if self.cfg.hub_model_id:
-            training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
-            training_args_kwargs["push_to_hub"] = True
-            training_args_kwargs["hub_private_repo"] = True
-            training_args_kwargs["hub_always_push"] = True
+        return training_args_kwargs
 
-        if self.cfg.hub_strategy:
-            training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
+    def _configure_optimizer_and_scheduler(self):
+        training_args_kwargs = {}
 
-        # save_strategy and save_steps
-        if self.cfg.save_steps:
-            training_args_kwargs["save_strategy"] = "steps"
-            training_args_kwargs["save_steps"] = self.cfg.save_steps
-        elif self.cfg.save_strategy:
-            training_args_kwargs["save_strategy"] = self.cfg.save_strategy
-        else:
-            # default to saving each epoch if not defined
-            training_args_kwargs["save_strategy"] = "epoch"
-
-        # eval_strategy and eval_steps
-        if not self.eval_dataset or self.cfg.val_set_size == 0:
-            # do not eval if no eval_dataset or val_set_size=0
-            training_args_kwargs["eval_strategy"] = "no"
-        elif self.cfg.eval_steps:
-            training_args_kwargs["eval_strategy"] = "steps"
-            training_args_kwargs["eval_steps"] = self.cfg.eval_steps
-        elif self.cfg.eval_strategy:
-            training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
-
-        if self.cfg.gradient_checkpointing:
-            training_args_kwargs["gradient_checkpointing"] = (
-                self.cfg.gradient_checkpointing
-            )
-            if self.cfg.gradient_checkpointing_kwargs is not None:
-                training_args_kwargs["gradient_checkpointing_kwargs"] = (
-                    self.cfg.gradient_checkpointing_kwargs
-                )
-            else:
-                training_args_kwargs["gradient_checkpointing_kwargs"] = {
-                    "use_reentrant": False
-                }
-
-        # set arg into trainer_args_kwargs with same name if value not None
-        for arg in [
-            "adam_beta1",
-            "adam_beta2",
-            "adam_epsilon",
-            "max_grad_norm",
-            "dataloader_num_workers",
-            "dataloader_pin_memory",
-            "dataloader_prefetch_factor",
-            "gradient_accumulation_steps",
-            "learning_rate",
-            "embedding_lr",
-            "embedding_lr_scale",
-            "lr_groups",
-            "loraplus_lr_ratio",
-            "loraplus_lr_embedding",
-            "output_dir",
-            "save_safetensors",
-            "save_only_model",
-            "include_tokens_per_second",
"weight_decay", - "sequence_parallel_degree", - "ring_attn_func", - "seed", - ]: - if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: - training_args_kwargs[arg] = getattr(self.cfg, arg) - - training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size - - if self.cfg.eval_batch_size: - training_args_kwargs["per_device_eval_batch_size"] = ( - self.cfg.eval_batch_size - ) - - training_args_kwargs["save_total_limit"] = ( - self.cfg.save_total_limit if self.cfg.save_total_limit else 4 - ) - - training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 - - training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs - - # max_length is not used in CausalTrainer - if self.cfg.reward_model or self.cfg.rl: - training_args_kwargs["max_length"] = self.cfg.sequence_len - - # reporting - report_to = [] - if self.cfg.use_wandb: - report_to.append("wandb") - if self.cfg.use_mlflow: - report_to.append("mlflow") - if self.cfg.use_tensorboard: - report_to.append("tensorboard") - if self.cfg.use_comet: - report_to.append("comet_ml") - - training_args_kwargs["report_to"] = report_to - if self.cfg.use_wandb: - training_args_kwargs["run_name"] = self.cfg.wandb_name - elif self.cfg.use_mlflow: - training_args_kwargs["run_name"] = self.cfg.mlflow_run_name - else: - training_args_kwargs["run_name"] = None - - # optim/scheduler if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]: training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler @@ -462,7 +361,78 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.optim_target_modules: training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules - # torch compile + return training_args_kwargs + + def _configure_hub_parameters(self): + training_args_kwargs = {} + + if self.cfg.hub_model_id: + training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id + training_args_kwargs["push_to_hub"] = True + training_args_kwargs["hub_private_repo"] = True + training_args_kwargs["hub_always_push"] = True + + if self.cfg.hub_strategy: + training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy + + return training_args_kwargs + + def _configure_save_and_eval_strategy(self): + training_args_kwargs = {} + + # save_strategy and save_steps + if self.cfg.save_steps: + training_args_kwargs["save_strategy"] = "steps" + training_args_kwargs["save_steps"] = self.cfg.save_steps + elif self.cfg.save_strategy: + training_args_kwargs["save_strategy"] = self.cfg.save_strategy + else: + # default to saving each epoch if not defined + training_args_kwargs["save_strategy"] = "epoch" + + training_args_kwargs["save_total_limit"] = ( + self.cfg.save_total_limit if self.cfg.save_total_limit else 4 + ) + + # eval_strategy and eval_steps + if not self.eval_dataset or self.cfg.val_set_size == 0: + # do not eval if no eval_dataset or val_set_size=0 + training_args_kwargs["eval_strategy"] = "no" + elif self.cfg.eval_steps: + training_args_kwargs["eval_strategy"] = "steps" + training_args_kwargs["eval_steps"] = self.cfg.eval_steps + elif self.cfg.eval_strategy: + training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy + + return training_args_kwargs + + def _configure_reporting(self): + training_args_kwargs = {} + + report_to = [] + if self.cfg.use_wandb: + report_to.append("wandb") + if self.cfg.use_mlflow: + report_to.append("mlflow") + if self.cfg.use_tensorboard: + report_to.append("tensorboard") + if self.cfg.use_comet: + report_to.append("comet_ml") 
+
+        training_args_kwargs["report_to"] = report_to
+
+        if self.cfg.use_wandb:
+            training_args_kwargs["run_name"] = self.cfg.wandb_name
+        elif self.cfg.use_mlflow:
+            training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
+        else:
+            training_args_kwargs["run_name"] = None
+
+        return training_args_kwargs
+
+    def _configure_torch_compile(self):
+        training_args_kwargs = {}
+
         if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
             torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
                 True
@@ -476,3 +446,85 @@ class TrainerBuilderBase(abc.ABC):
             training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
         return training_args_kwargs
+
+    def _configure_gradient_checkpointing(self):
+        training_args_kwargs = {}
+
+        if self.cfg.gradient_checkpointing:
+            training_args_kwargs["gradient_checkpointing"] = (
+                self.cfg.gradient_checkpointing
+            )
+            if self.cfg.gradient_checkpointing_kwargs is not None:
+                training_args_kwargs["gradient_checkpointing_kwargs"] = (
+                    self.cfg.gradient_checkpointing_kwargs
+                )
+            else:
+                training_args_kwargs["gradient_checkpointing_kwargs"] = {
+                    "use_reentrant": False
+                }
+
+        return training_args_kwargs
+
+    def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
+        training_args_kwargs: Dict[str, Any] = {}
+
+        training_args_kwargs.update(self._configure_warmup_and_logging(total_num_steps))
+
+        training_args_kwargs.update(self._configure_precision_settings())
+
+        training_args_kwargs.update(self._configure_save_and_eval_strategy())
+
+        training_args_kwargs.update(self._configure_gradient_checkpointing())
+
+        # set arg into training_args_kwargs with same name if value not None
+        for arg in [
+            "adam_beta1",
+            "adam_beta2",
+            "adam_epsilon",
+            "max_grad_norm",
+            "dataloader_num_workers",
+            "dataloader_pin_memory",
+            "dataloader_prefetch_factor",
+            "gradient_accumulation_steps",
+            "learning_rate",
+            "embedding_lr",
+            "embedding_lr_scale",
+            "lr_groups",
+            "loraplus_lr_ratio",
+            "loraplus_lr_embedding",
+            "output_dir",
+            "save_safetensors",
+            "save_only_model",
+            "include_tokens_per_second",
+            "weight_decay",
+            "sequence_parallel_degree",
+            "ring_attn_func",
+            "seed",
+        ]:
+            if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
+                training_args_kwargs[arg] = getattr(self.cfg, arg)
+
+        training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
+
+        if self.cfg.eval_batch_size:
+            training_args_kwargs["per_device_eval_batch_size"] = (
+                self.cfg.eval_batch_size
+            )
+
+        training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
+
+        training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
+
+        # max_length is not used in CausalTrainer
+        if self.cfg.reward_model or self.cfg.rl:
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+
+        training_args_kwargs.update(self._configure_reporting())
+
+        training_args_kwargs.update(self._configure_hub_parameters())
+
+        training_args_kwargs.update(self._configure_optimizer_and_scheduler())
+
+        training_args_kwargs.update(self._configure_torch_compile())
+
+        return training_args_kwargs
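
A quick sketch of the pattern this refactor converges on, for review context
(hypothetical, simplified names, not code from this patch): each `_configure_*`
helper builds an independent kwargs fragment, and `_set_base_training_args`
merges the fragments with dict.update(). Because update() lets a later fragment
silently overwrite an earlier one on a key collision, the helpers must stay
key-disjoint; merge order is then irrelevant.

    from typing import Any

    class Cfg:
        """Stand-in for self.cfg (hypothetical)."""

        bf16 = True
        fp16 = False
        use_wandb = True

    def configure_precision(cfg: Cfg) -> dict[str, Any]:
        # Fragment 1: mixed-precision flags only.
        return {"bf16": cfg.bf16, "fp16": cfg.fp16 and not cfg.bf16}

    def configure_reporting(cfg: Cfg) -> dict[str, Any]:
        # Fragment 2: reporting integrations only.
        return {"report_to": ["wandb"] if cfg.use_wandb else []}

    def set_base_training_args(cfg: Cfg) -> dict[str, Any]:
        kwargs: dict[str, Any] = {}
        # Later update() calls win on duplicate keys, so the
        # fragments above are deliberately kept key-disjoint.
        kwargs.update(configure_precision(cfg))
        kwargs.update(configure_reporting(cfg))
        return kwargs

    print(set_base_training_args(Cfg()))
    # -> {'bf16': True, 'fp16': False, 'report_to': ['wandb']}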