adding reward fn verification
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
"""
|
||||
GRPO Specific Strategy for training
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class GRPOStrategy:
|
||||
"""
|
||||
@@ -43,14 +50,8 @@ class GRPOStrategy:
|
||||
trainer_kwargs = {}
|
||||
if cfg.grpo_reward_funcs:
|
||||
reward_funcs = []
|
||||
for reward_func_module in cfg.grpo_reward_funcs:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
reward_func_module_name = reward_func_module.split(".")[-1]
|
||||
reward_func_module = importlib.import_module(
|
||||
reward_func_module.split(".")[-2]
|
||||
)
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
reward_funcs.append(reward_func)
|
||||
for reward_func_fqn in cfg.grpo_reward_funcs:
|
||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||
return trainer_kwargs
|
||||
|
||||
@@ -62,3 +63,36 @@ class GRPOStrategy:
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls):
|
||||
return ["dataset_num_proc"]
|
||||
|
||||
@classmethod
|
||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||
"""
|
||||
Returns the reward function from the given fully qualified name, or the path to the reward function model.
|
||||
|
||||
Args:
|
||||
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||
or a HF hub path to the reward model.
|
||||
Raises:
|
||||
ValueError: If the reward function does not accept at least two arguments.
|
||||
|
||||
Returns:
|
||||
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||
or a path to a reward model.
|
||||
|
||||
"""
|
||||
try:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
reward_func_module_name = reward_func_fqn.split(".")[-1]
|
||||
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
if not len(inspect.signature(reward_func).parameters) >= 2:
|
||||
raise ValueError(
|
||||
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||
)
|
||||
return reward_func
|
||||
except ModuleNotFoundError:
|
||||
# the user has passed a string (ideally indicating the path of a reward model)
|
||||
LOG.info(
|
||||
f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
)
|
||||
return reward_func
|
||||
|
||||
Reference in New Issue
Block a user