more fixes to get grpo working

This commit is contained in:
Wing Lian
2025-02-03 08:32:44 -05:00
parent cfc7fe0df2
commit 1e94d7ef65
2 changed files with 11 additions and 12 deletions

View File

@@ -797,9 +797,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def build_collator(
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if self.cfg.rl == "grpo":
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
@@ -982,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -1014,14 +1012,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
else:
training_args_cls = DPOConfig.get_training_args_class()
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1054,6 +1053,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()

View File

@@ -44,19 +44,18 @@ class GRPOStrategy:
for reward_func_module in cfg.grpo_reward_funcs:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_module.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_module)
reward_func_module = importlib.import_module(
reward_func_module.split(".")[-2]
)
reward_func = getattr(reward_func_module, reward_func_module_name)
reward_funcs.append(reward_func)
trainer_kwargs["reward_funcs"] = reward_funcs
trainer_kwargs["data_collator"] = cls.get_collator(cfg)
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
def data_collator(features): # No data collation is needed in GRPO
return features
return data_collator
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None
@classmethod
def get_blocklist_args_kwargs(cls):