Fix bug in grpo reward module import (#2571)

This commit is contained in:
Dhruv Mullick
2025-04-27 22:31:56 -06:00
committed by GitHub
parent dc4da4a7e2
commit 8b33ae1c4f

View File

@@ -135,7 +135,9 @@ class GRPOStrategy:
try:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
reward_func_module = importlib.import_module(
".".join(reward_func_fqn.split(".")[:-1])
)
reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError(