diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index b04373a95..0a7a22f2d 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -53,6 +53,10 @@ class GRPOStrategy: for reward_func_fqn in cfg.grpo_reward_funcs: reward_funcs.append(cls.get_reward_func(reward_func_fqn)) trainer_kwargs["reward_funcs"] = reward_funcs + if cfg.grpo_reward_processing_classes: + trainer_kwargs[ + "reward_processing_classes" + ] = cfg.grpo_reward_processing_classes return trainer_kwargs @classmethod