load the class from strat
This commit is contained in:
@@ -1048,7 +1048,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
if self.cfg.rl == "grpo":
|
||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
elif self.cfg.rl == "orpo":
|
||||
|
||||
@@ -48,13 +48,11 @@ class GRPOStrategy:
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
reward_funcs.append(reward_func)
|
||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||
|
||||
trainer_kwargs["data_collator"] = cls.get_collator(cfg)
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_collator(
|
||||
cls, cfg, training_args, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
|
||||
Reference in New Issue
Block a user