Load the trainer class from the strategy

This commit is contained in:
Wing Lian
2025-02-03 00:31:55 -05:00
parent d155849e2c
commit b1c4711145
2 changed files with 6 additions and 5 deletions

View File

@@ -1048,7 +1048,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":

View File

@@ -48,13 +48,11 @@ class GRPOStrategy:
reward_func = getattr(reward_func_module, reward_func_module_name)
reward_funcs.append(reward_func)
trainer_kwargs["reward_funcs"] = reward_funcs
trainer_kwargs["data_collator"] = cls.get_collator(cfg)
return trainer_kwargs
@classmethod
def get_collator(
cls, cfg, training_args, **kwargs
): # pylint: disable=unused-argument
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
def data_collator(features): # No data collation is needed in GRPO
return features