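Add support for num_generations: a new optional `grpo_num_generations` config field that, when set, is forwarded to the GRPO args as `num_generations`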

This commit is contained in:
Wing Lian
2025-02-03 22:10:32 -05:00
parent 56f3b9f20f
commit ac1ebc58a8
2 changed files with 3 additions and 0 deletions

View File

@@ -34,6 +34,8 @@ class GRPOStrategy:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.grpo_vllm_gpu_memory_utilization
if cfg.grpo_num_generations:
grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations
return grpo_args_kwargs
@classmethod

View File

@@ -15,3 +15,4 @@ class GRPOConfig(BaseModel):
grpo_vllm_device: Optional[str] = "auto"
grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9
grpo_reward_funcs: Optional[List[str]] = None
grpo_num_generations: Optional[int] = None