From c268a0157a4ce34b523d2ad9e618cfd479b799c1 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 16 Dec 2024 15:59:41 +0700
Subject: [PATCH] feat: add report_to to set run name

---
 src/axolotl/core/trainer_builder.py | 56 ++++++++++++++++++++---------
 1 file changed, 40 insertions(+), 16 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 5ab62343a..da6384c8e 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -323,6 +323,24 @@ class TrainerBuilderBase(abc.ABC):
             total_num_steps if self.cfg.max_steps else -1
         )
 
+        report_to = []
+        if self.cfg.use_wandb:
+            report_to.append("wandb")
+        if self.cfg.use_mlflow:
+            report_to.append("mlflow")
+        if self.cfg.use_tensorboard:
+            report_to.append("tensorboard")
+        if self.cfg.use_comet:
+            report_to.append("comet_ml")
+
+        training_args_kwargs["report_to"] = report_to
+        if self.cfg.use_wandb:
+            training_args_kwargs["run_name"] = self.cfg.wandb_name
+        elif self.cfg.use_mlflow:
+            training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
+        else:
+            training_args_kwargs["run_name"] = None
+
         return training_args_kwargs
 
@@ -548,23 +566,29 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
 
-        report_to = []
-        if self.cfg.use_wandb:
-            report_to.append("wandb")
-        if self.cfg.use_mlflow:
-            report_to.append("mlflow")
-        if self.cfg.use_tensorboard:
-            report_to.append("tensorboard")
-        if self.cfg.use_comet:
-            report_to.append("comet_ml")
-        training_arguments_kwargs["report_to"] = report_to
-        if self.cfg.use_wandb:
-            training_arguments_kwargs["run_name"] = self.cfg.wandb_name
-        elif self.cfg.use_mlflow:
-            training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
-        else:
-            training_arguments_kwargs["run_name"] = None
+        training_arguments_kwargs["optim"] = (
+            self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
+        )
+        if self.cfg.optim_args:
+            if isinstance(self.cfg.optim_args, dict):
+                optim_args = ",".join(
+                    [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
+                )
+            else:
+                optim_args = self.cfg.optim_args
+            training_arguments_kwargs["optim_args"] = optim_args
+        if self.cfg.optim_target_modules:
+            training_arguments_kwargs[
+                "optim_target_modules"
+            ] = self.cfg.optim_target_modules
+        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_arguments_kwargs[
+            "loraplus_lr_embedding"
+        ] = self.cfg.loraplus_lr_embedding
+        training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
+        training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
 
         if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
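
Review note: the hunk added to TrainerBuilderBase collects the enabled logger integrations into report_to and resolves run_name, with wandb taking precedence over mlflow; both kwargs are ultimately forwarded to transformers.TrainingArguments. A minimal standalone sketch of that resolution follows (the SimpleNamespace cfg and the run name value are illustrative, not from the patch):

    # Sketch only: mirrors the report_to/run_name logic from the first hunk.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        use_wandb=True,
        use_mlflow=False,
        use_tensorboard=True,
        use_comet=False,
        wandb_name="example-run",  # illustrative value
        mlflow_run_name=None,
    )

    report_to = []
    if cfg.use_wandb:
        report_to.append("wandb")
    if cfg.use_mlflow:
        report_to.append("mlflow")
    if cfg.use_tensorboard:
        report_to.append("tensorboard")
    if cfg.use_comet:
        report_to.append("comet_ml")

    # wandb wins over mlflow when both are enabled, matching the if/elif order
    if cfg.use_wandb:
        run_name = cfg.wandb_name
    elif cfg.use_mlflow:
        run_name = cfg.mlflow_run_name
    else:
        run_name = None

    assert report_to == ["wandb", "tensorboard"]
    assert run_name == "example-run"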
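
The second hunk also normalizes optim_args: transformers.TrainingArguments expects optim_args as a single comma-separated key=value string (consumed by optimizers such as GaLore, which also reads optim_target_modules), so a dict-valued config entry is flattened before being passed along. A short sketch with illustrative key names:

    # Sketch only: dict-valued optim_args flattened to the string form
    # TrainingArguments.optim_args expects; the keys here are illustrative.
    optim_args = {"rank": 64, "update_proj_gap": 200}
    if isinstance(optim_args, dict):
        optim_args = ",".join(f"{key}={value}" for key, value in optim_args.items())
    assert optim_args == "rank=64,update_proj_gap=200"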