more fixes to get grpo working
This commit is contained in:
@@ -797,9 +797,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
def build_collator(
|
||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||
):
|
||||
if self.cfg.rl == "grpo":
|
||||
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
|
||||
|
||||
if training_args.pretraining:
|
||||
if self.cfg.pretraining_sample_concatenation is False:
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
@@ -982,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
@@ -1014,14 +1012,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
else:
|
||||
training_args_cls = DPOConfig.get_training_args_class()
|
||||
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
|
||||
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -1054,6 +1053,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl == "grpo":
|
||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model]
|
||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
|
||||
@@ -44,19 +44,18 @@ class GRPOStrategy:
|
||||
for reward_func_module in cfg.grpo_reward_funcs:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
reward_func_module_name = reward_func_module.split(".")[-1]
|
||||
reward_func_module = importlib.import_module(reward_func_module)
|
||||
reward_func_module = importlib.import_module(
|
||||
reward_func_module.split(".")[-2]
|
||||
)
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
reward_funcs.append(reward_func)
|
||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||
trainer_kwargs["data_collator"] = cls.get_collator(cfg)
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
return data_collator
|
||||
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls):
|
||||
|
||||
Reference in New Issue
Block a user