set default on trl config

This commit is contained in:
Wing Lian
2025-02-05 22:17:10 -05:00
parent aded9c500d
commit 756a8332d6
3 changed files with 7 additions and 7 deletions

View File

@@ -969,7 +969,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
  if self.cfg.dataset_processes:
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
- if self.cfg.trl.beta or self.cfg.rl_beta:
+ if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
  training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
  if self.cfg.orpo_alpha:
  # trl does some odd mapping of alpha to beta to reuse the beta parameter ???

View File

@@ -33,25 +33,25 @@ class GRPOStrategy:
  grpo_args_kwargs = {}
  if cfg.trl and cfg.trl.use_vllm:
  grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
- if cfg.trl.vllm_device:
+ if cfg.trl and cfg.trl.vllm_device:
  grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
  else:
  grpo_args_kwargs["vllm_device"] = "auto"
- if cfg.trl.vllm_gpu_memory_utilization:
+ if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
  grpo_args_kwargs[
  "vllm_gpu_memory_utilization"
  ] = cfg.trl.vllm_gpu_memory_utilization
- if cfg.trl.vllm_max_model_len:
+ if cfg.trl and cfg.trl.vllm_max_model_len:
  grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
  if cfg.trl and cfg.trl.num_generations:
  grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
  if cfg.trl and cfg.trl.sync_ref_model:
  grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
- if cfg.trl.ref_model_mixup_alpha:
+ if cfg.trl and cfg.trl.ref_model_mixup_alpha:
  grpo_args_kwargs[
  "ref_model_mixup_alpha"
  ] = cfg.trl.ref_model_mixup_alpha
- if cfg.trl.ref_model_sync_steps:
+ if cfg.trl and cfg.trl.ref_model_sync_steps:
  grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
  grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
  return grpo_args_kwargs

View File

@@ -671,7 +671,7 @@ class AxolotlInputConfig(
  rl: Optional[RLType] = None
  trl: Optional[TrlConfig] = Field(
- default_factory=lambda: TrlConfig() # pylint: disable=unnecessary-lambda
+ default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
  )
  reward_model: Optional[bool] = None
  process_reward_model: Optional[bool] = None