adding reward fn verification
This commit is contained in:
@@ -1,10 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
GRPO Specific Strategy for training
|
GRPO Specific Strategy for training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class GRPOStrategy:
|
class GRPOStrategy:
|
||||||
"""
|
"""
|
||||||
@@ -43,14 +50,8 @@ class GRPOStrategy:
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if cfg.grpo_reward_funcs:
|
if cfg.grpo_reward_funcs:
|
||||||
reward_funcs = []
|
reward_funcs = []
|
||||||
for reward_func_module in cfg.grpo_reward_funcs:
|
for reward_func_fqn in cfg.grpo_reward_funcs:
|
||||||
# use importlib to dynamically load the reward function from the module
|
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||||
reward_func_module_name = reward_func_module.split(".")[-1]
|
|
||||||
reward_func_module = importlib.import_module(
|
|
||||||
reward_func_module.split(".")[-2]
|
|
||||||
)
|
|
||||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
|
||||||
reward_funcs.append(reward_func)
|
|
||||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@@ -62,3 +63,36 @@ class GRPOStrategy:
|
|||||||
@classmethod
def get_blocklist_args_kwargs(cls):
    """
    Returns the blocklisted argument names for this strategy.

    Returns:
        list[str]: A new list holding the single entry ``"dataset_num_proc"``.
    """
    blocked_names = ["dataset_num_proc"]
    return blocked_names
|
||||||
|
|
||||||
|
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
    """
    Returns the reward function from the given fully qualified name, or the path to the reward function model.

    The name is split at its last dot: everything before it is imported as a module
    and the final component is looked up on that module. If the name has no
    importable module part, or that module cannot be found, the string is returned
    unchanged and treated as a path to a reward model.

    Args:
        reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
            or a HF hub path to the reward model.

    Raises:
        ValueError: If the reward function does not accept at least two arguments.
        ModuleNotFoundError: If the reward module was found but an import inside it failed.
        AttributeError: If the module was imported but does not define the named attribute.

    Returns:
        RewardFunc: A callable that accepts prompts and completions and returns rewards,
            or a path to a reward model.
    """
    module_path, _, func_name = reward_func_fqn.rpartition(".")

    reward_func_module = None
    # A name without a module part (e.g. "org/model") or with a relative one
    # (e.g. "./model") cannot be imported; fall through to the model-path branch.
    if module_path and not module_path.startswith("."):
        try:
            # use importlib to dynamically load the reward function from the module
            reward_func_module = importlib.import_module(module_path)
        except ModuleNotFoundError as exc:
            # Only swallow the error when the requested module (or one of its parent
            # packages) is what is missing. A missing dependency *inside* the user's
            # reward module is a real error and must not be mistaken for a model path.
            if (
                exc.name
                and exc.name != module_path
                and not module_path.startswith(f"{exc.name}.")
            ):
                raise

    if reward_func_module is None:
        # the user has passed a string (ideally indicating the path of a reward model)
        LOG.info(
            "Reward function %s is a pre-trained model path - if this is unexpected, "
            "please check the reward function path.",
            reward_func_fqn,
        )
        return reward_func_fqn

    reward_func = getattr(reward_func_module, func_name)
    if len(inspect.signature(reward_func).parameters) < 2:
        raise ValueError(
            "Reward function must accept at least two arguments: prompts: list and completions: list"
        )
    return reward_func
|
||||||
|
|||||||
Reference in New Issue
Block a user