set default on trl config
This commit is contained in:
@@ -969,7 +969,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.trl.beta or self.cfg.rl_beta:
|
||||
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
|
||||
@@ -33,25 +33,25 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.trl and cfg.trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||
if cfg.trl.vllm_device:
|
||||
if cfg.trl and cfg.trl.vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.trl.vllm_gpu_memory_utilization:
|
||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.trl.vllm_gpu_memory_utilization
|
||||
if cfg.trl.vllm_max_model_len:
|
||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||
if cfg.trl and cfg.trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||
if cfg.trl and cfg.trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||
if cfg.trl.ref_model_mixup_alpha:
|
||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs[
|
||||
"ref_model_mixup_alpha"
|
||||
] = cfg.trl.ref_model_mixup_alpha
|
||||
if cfg.trl.ref_model_sync_steps:
|
||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||
return grpo_args_kwargs
|
||||
|
||||
@@ -671,7 +671,7 @@ class AxolotlInputConfig(
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
trl: Optional[TrlConfig] = Field(
|
||||
default_factory=lambda: TrlConfig() # pylint: disable=unnecessary-lambda
|
||||
default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
reward_model: Optional[bool] = None
|
||||
process_reward_model: Optional[bool] = None
|
||||
|
||||
Reference in New Issue
Block a user