From ac1ebc58a8f8c053bd3864524d0193165277c9fc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 3 Feb 2025 22:10:32 -0500 Subject: [PATCH] add support for num_generations --- src/axolotl/core/trainers/grpo/__init__.py | 2 ++ src/axolotl/utils/config/models/input/v0_4_1/grpo.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 07db1dc46..c81c44e66 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -34,6 +34,8 @@ class GRPOStrategy: grpo_args_kwargs[ "vllm_gpu_memory_utilization" ] = cfg.grpo_vllm_gpu_memory_utilization + if cfg.grpo_num_generations: + grpo_args_kwargs["num_generations"] = cfg.grpo_num_generations return grpo_args_kwargs @classmethod diff --git a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py index 857d93cf9..39ac341c6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/grpo.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/grpo.py @@ -15,3 +15,4 @@ class GRPOConfig(BaseModel): grpo_vllm_device: Optional[str] = "auto" grpo_vllm_gpu_memory_utilization: Optional[float] = 0.9 grpo_reward_funcs: Optional[List[str]] = None + grpo_num_generations: Optional[int] = None