diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d270ba00d..c5136ce0a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1048,7 +1048,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo"]:
+        if self.cfg.rl == "grpo":
+            trainer_cls = GRPOStrategy.get_trainer_class()
+            dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
+        elif self.cfg.rl in ["dpo", "ipo"]:
             trainer_cls = DPOStrategy.get_trainer_class()
             trainer_cls_args = [self.model, self.model_ref]
         elif self.cfg.rl == "orpo":
diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index 520ecc3c2..9c2fa583b 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -48,13 +48,11 @@ class GRPOStrategy:
                 reward_func = getattr(reward_func_module, reward_func_module_name)
                 reward_funcs.append(reward_func)
             trainer_kwargs["reward_funcs"] = reward_funcs
-
+        trainer_kwargs["data_collator"] = cls.get_collator(cfg)
         return trainer_kwargs
 
     @classmethod
-    def get_collator(
-        cls, cfg, training_args, **kwargs
-    ):  # pylint: disable=unused-argument
+    def get_collator(cls, *args, **kwargs):  # pylint: disable=unused-argument
         def data_collator(features):
             # No data collation is needed in GRPO
             return features