fix support for wandb run_name for rl trainers (#2566) [skip ci]

* fix support for wandb run_name for rl trainers

* prefer to use wandb random names for run_name
This commit is contained in:
Wing Lian
2025-04-25 21:10:54 -04:00
committed by GitHub
parent e3c9d541a7
commit 5dba5c82a8

View File

@@ -1048,6 +1048,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1118,6 +1121,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs,
)
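# When wandb is enabled and run_name is merely equal to output_dir (presumably
# the TrainingArguments fallback when no wandb_name was configured -- confirm),
# clear it so that wandb generates its own experiment name.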
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):