adding reward fn verification

This commit is contained in:
Salman Mohammadi
2025-02-05 13:30:02 +00:00
parent 753146b458
commit b8f258817e

View File

@@ -1,10 +1,17 @@
"""
GRPO Specific Strategy for training
"""
import importlib
import inspect
import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
LOG = logging.getLogger("axolotl")
class GRPOStrategy:
"""
@@ -43,14 +50,8 @@ class GRPOStrategy:
trainer_kwargs = {}
if cfg.grpo_reward_funcs:
reward_funcs = []
for reward_func_module in cfg.grpo_reward_funcs:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_module.split(".")[-1]
reward_func_module = importlib.import_module(
reward_func_module.split(".")[-2]
)
reward_func = getattr(reward_func_module, reward_func_module_name)
reward_funcs.append(reward_func)
for reward_func_fqn in cfg.grpo_reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
return trainer_kwargs
@@ -62,3 +63,36 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls):
return ["dataset_num_proc"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
"""
Returns the reward function from the given fully qualified name, or the path to the reward function model.
Args:
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
or a HF hub path to the reward model.
Raises:
ValueError: If the reward function does not accept at least two arguments.
Returns:
RewardFunc: A callable that accepts prompts and completions and returns rewards,
or a path to a reward model.
"""
try:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError(
"Reward function must accept at least two arguments: prompts: list and completions: list"
)
return reward_func
except ModuleNotFoundError:
# the user has passed a string (ideally indicating the path of a reward model)
LOG.info(
f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func