From a4ee56c3156f334c6eb3ff3a00b7f3f4efe5888f Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:36:15 +0530 Subject: [PATCH] fix: set rollout in GRPO training_kwargs (#3392) --- src/axolotl/core/trainers/grpo/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 8b295e537..0d2615aec 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -126,9 +126,6 @@ class GRPOStrategy: if trl.use_liger_loss is not None: grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss - if trl.rollout_func: - grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func) - if trl.multi_objective_aggregation is not None: grpo_args_kwargs["multi_objective_aggregation"] = ( trl.multi_objective_aggregation @@ -154,6 +151,8 @@ class GRPOStrategy: trainer_kwargs["reward_processing_classes"] = ( cfg.trl.reward_processing_classes ) + if cfg.trl and cfg.trl.rollout_func: + trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func) return trainer_kwargs