From 1e94d7ef6503c616f064e23fcfa894a1849b9ecc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 3 Feb 2025 08:32:44 -0500 Subject: [PATCH] more fixes to get grpo working --- src/axolotl/core/trainer_builder.py | 12 ++++++------ src/axolotl/core/trainers/grpo/__init__.py | 11 +++++------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b95263f05..2a7f9ca21 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -797,9 +797,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): def build_collator( self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs ): - if self.cfg.rl == "grpo": - return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs) - if training_args.pretraining: if self.cfg.pretraining_sample_concatenation is False: return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) @@ -982,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_cls = None + blocklist_args_kwargs = [] if self.cfg.rl == "simpo": training_args_cls = AxolotlCPOConfig training_args_kwargs["loss_type"] = "simpo" @@ -1014,14 +1012,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = GRPOStrategy.get_training_args_class() training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() - for blocklist_key in blocklist_args_kwargs: - if blocklist_key in training_args_kwargs: - del training_args_kwargs[blocklist_key] else: training_args_cls = DPOConfig.get_training_args_class() training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg)) + for blocklist_key in blocklist_args_kwargs: + if blocklist_key in training_args_kwargs: + del training_args_kwargs[blocklist_key] + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg self.cfg.output_dir, per_device_train_batch_size=self.cfg.micro_batch_size, @@ -1054,6 +1053,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): ] = self.cfg.precompute_ref_log_probs if self.cfg.rl == "grpo": trainer_cls = GRPOStrategy.get_trainer_class() + trainer_cls_args = [self.model] dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) elif self.cfg.rl in ["dpo", "ipo"]: trainer_cls = DPOStrategy.get_trainer_class() diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 01e19976f..07db1dc46 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -44,19 +44,18 @@ class GRPOStrategy: for reward_func_module in cfg.grpo_reward_funcs: # use importlib to dynamically load the reward function from the module reward_func_module_name = reward_func_module.split(".")[-1] - reward_func_module = importlib.import_module(reward_func_module) + reward_func_module = importlib.import_module( + reward_func_module.split(".")[-2] + ) reward_func = getattr(reward_func_module, reward_func_module_name) reward_funcs.append(reward_func) trainer_kwargs["reward_funcs"] = reward_funcs - trainer_kwargs["data_collator"] = cls.get_collator(cfg) return trainer_kwargs @classmethod def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument - def data_collator(features): # No data collation is needed in GRPO - return features - - return data_collator + # No data collation is needed in GRPO, handled by trl's trainer __init__ + return None @classmethod def get_blocklist_args_kwargs(cls):