feat: add report_to to set run name
This commit is contained in:
@@ -323,6 +323,24 @@ class TrainerBuilderBase(abc.ABC):
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.use_mlflow:
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
return training_args_kwargs
|
||||
|
||||
|
||||
@@ -548,23 +566,29 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.use_mlflow:
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_arguments_kwargs["report_to"] = report_to
|
||||
if self.cfg.use_wandb:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_arguments_kwargs["run_name"] = None
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_arguments_kwargs["optim_args"] = optim_args
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
|
||||
Reference in New Issue
Block a user