From b8f258817e11ae026a4b68490657d3e1e82b11e6 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 5 Feb 2025 13:30:02 +0000 Subject: [PATCH] adding reward fn verification --- src/axolotl/core/trainers/grpo/__init__.py | 50 ++++++++++++++++++---- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index c81c44e66..b04373a95 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -1,10 +1,17 @@ """ GRPO Specific Strategy for training """ + import importlib +import inspect +import logging + +from trl.trainer.grpo_trainer import RewardFunc from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer +LOG = logging.getLogger("axolotl") + class GRPOStrategy: """ @@ -43,14 +50,8 @@ class GRPOStrategy: trainer_kwargs = {} if cfg.grpo_reward_funcs: reward_funcs = [] - for reward_func_module in cfg.grpo_reward_funcs: - # use importlib to dynamically load the reward function from the module - reward_func_module_name = reward_func_module.split(".")[-1] - reward_func_module = importlib.import_module( - reward_func_module.split(".")[-2] - ) - reward_func = getattr(reward_func_module, reward_func_module_name) - reward_funcs.append(reward_func) + for reward_func_fqn in cfg.grpo_reward_funcs: + reward_funcs.append(cls.get_reward_func(reward_func_fqn)) trainer_kwargs["reward_funcs"] = reward_funcs return trainer_kwargs @@ -62,3 +63,36 @@ class GRPOStrategy: @classmethod def get_blocklist_args_kwargs(cls): return ["dataset_num_proc"] + + @classmethod + def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: + """ + Returns the reward function from the given fully qualified name, or the path to the reward function model. + + Args: + reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform), + or a HF hub path to the reward model. + Raises: + ValueError: If the reward function does not accept at least two arguments. + + Returns: + RewardFunc: A callable that accepts prompts and completions and returns rewards, + or a path to a reward model. + + """ + try: + # use importlib to dynamically load the reward function from the module + reward_func_module_name = reward_func_fqn.split(".")[-1] + reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2]) + reward_func = getattr(reward_func_module, reward_func_module_name) + if not len(inspect.signature(reward_func).parameters) >= 2: + raise ValueError( + "Reward function must accept at least two arguments: prompts: list and completions: list" + ) + return reward_func + except ModuleNotFoundError: + # the user has passed a string (ideally indicating the path of a reward model) + LOG.info( + f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path." + ) + return reward_func