diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 07db1dc46..c81c44e66 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -34,6 +34,8 @@ class GRPOStrategy: grpo_args_kwargs[ "vllm_gpu_memory_utilization" ] = cfg.grpo_vllm_gpu_memory_utilization + if cfg.grpo_num_generations: + grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations return grpo_args_kwargs @classmethod diff --git a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py index 857d93cf9..39ac341c6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py @@ -15,3 +15,4 @@ class GRPOConfig(BaseModel): grpo_vllm_device: Optional[str] = "auto" grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9 grpo_reward_funcs: Optional[List[str]] = None + grpo_num_generations: Optional[int] = None