fix support for wandb run_name for rl trainers (#2566) [skip ci]
* fix support for wandb run_name for rl trainers * prefer to use wandb random names for run_name
This commit is contained in:
@@ -1048,6 +1048,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
|
if self.cfg.use_wandb:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
@@ -1118,6 +1121,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
**training_args_kwargs,
|
**training_args_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# unset run_name so wandb sets up experiment names
|
||||||
|
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||||
|
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
return training_args
|
return training_args
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
|
|||||||
Reference in New Issue
Block a user