add support for num_generations
This commit is contained in:
@@ -34,6 +34,8 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs[
|
grpo_args_kwargs[
|
||||||
"vllm_gpu_memory_utilization"
|
"vllm_gpu_memory_utilization"
|
||||||
] = cfg.grpo_vllm_gpu_memory_utilization
|
] = cfg.grpo_vllm_gpu_memory_utilization
|
||||||
|
if cfg.grpo_num_generations:
|
||||||
|
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -15,3 +15,4 @@ class GRPOConfig(BaseModel):
|
|||||||
grpo_vllm_device: Optional[str] = "auto"
|
grpo_vllm_device: Optional[str] = "auto"
|
||||||
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
|
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||||
grpo_reward_funcs: Optional[List[str]] = None
|
grpo_reward_funcs: Optional[List[str]] = None
|
||||||
|
grpo_num_generations: Optional[int] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user