From 626db6cb8459ac1ae32eb5e4f66a40df0f0770b3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 3 Feb 2025 00:18:17 -0500
Subject: [PATCH] collator for grpo and prompt loader

---
Review notes (not part of the commit; git am ignores the text between the
"---" line and the first "diff --git" line):

 1. Line breaks and indentation in this copy were reconstructed from a
    whitespace-collapsed patch. Verify the hunk indentation against the
    tree before applying.
 2. trainer_builder.py: build_collator now returns
    GRPOStrategy.get_collator(self.cfg, training_args, **kwargs) when
    self.cfg.rl == "grpo", ahead of the pretraining checks.
 3. grpo/__init__.py: get_collator does not use cfg, training_args or
    kwargs; the collator it returns hands `features` back unchanged.
 4. base.py, load(): the module probe now passes strategy.split(".")[-2]
    with no leading ".", which importlib treats as an absolute import (the
    package argument is only consulted for relative names). Presumably
    meant to allow loaders that live outside module_base; confirm.
 5. NOTE(review), base.py, load(): the ModuleNotFoundError handler is
    `pass`, so `strategy` still carries the loader-function suffix when it
    reaches the final import_module(strategy, module_base) call. Confirm
    this fallback is intended.
 6. NOTE(review), base.py, load(): ".default" is appended to
    single-component names, so len(strategy.split(".")) > 1 always holds
    afterwards. If the new `else:` pairs with that `if` (indentation was
    lost; confirm), the branch is unreachable.

 src/axolotl/core/trainer_builder.py        |  3 +++
 src/axolotl/core/trainers/grpo/__init__.py |  9 +++++++++
 src/axolotl/prompt_strategies/base.py      | 13 +++++++------
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 19481d22c..d270ba00d 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -797,6 +797,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     def build_collator(
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
+        if self.cfg.rl == "grpo":
+            return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
+
         if training_args.pretraining:
             if self.cfg.pretraining_sample_concatenation is False:
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index bab8605e8..520ecc3c2 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -50,3 +50,12 @@ class GRPOStrategy:
             trainer_kwargs["reward_funcs"] = reward_funcs
 
         return trainer_kwargs
+
+    @classmethod
+    def get_collator(
+        cls, cfg, training_args, **kwargs
+    ):  # pylint: disable=unused-argument
+        def data_collator(features):  # No data collation is needed in GRPO
+            return features
+
+        return data_collator
diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py
index ae5f85450..b3f2413ba 100644
--- a/src/axolotl/prompt_strategies/base.py
+++ b/src/axolotl/prompt_strategies/base.py
@@ -13,18 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
         if len(strategy.split(".")) == 1:
             strategy = strategy + ".default"
         load_fn = strategy.split(".")[-1]
-        strategy = ".".join(strategy.split(".")[:-1])
         if len(strategy.split(".")) > 1:
             try:
                 importlib.import_module(
-                    "." + strategy.split(".")[-1],
-                    ".".join(strategy.split(".")[:-1]),
+                    strategy.split(".")[-2],
+                    ".".join(strategy.split(".")[:-2]),
                 )
-                module_base = ".".join(strategy.split(".")[:-1])
-                strategy = strategy.split(".")[-1]
+                module_base = ".".join(strategy.split(".")[:-2])
+                strategy = strategy.split(".")[-2]
             except ModuleNotFoundError:
                 pass
-        mod = importlib.import_module(f".{strategy}", module_base)
+        else:
+            strategy = "." + ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(strategy, module_base)
         func = getattr(mod, load_fn)
         return func(cfg, **kwargs)
     except Exception:  # pylint: disable=broad-exception-caught