diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 8b295e537..0d2615aec 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -126,9 +126,6 @@ class GRPOStrategy: if trl.use_liger_loss is not None: grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss - if trl.rollout_func: - grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func) - if trl.multi_objective_aggregation is not None: grpo_args_kwargs["multi_objective_aggregation"] = ( trl.multi_objective_aggregation @@ -154,6 +151,8 @@ class GRPOStrategy: trainer_kwargs["reward_processing_classes"] = ( cfg.trl.reward_processing_classes ) + if cfg.trl and cfg.trl.rollout_func: + trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func) return trainer_kwargs