better handling for reward function checks for GRPO (#2933) [skip ci]

* better handling for reward function checks for GRPO

* consolidate msg copy
This commit is contained in:
Wing Lian
2025-07-21 11:41:15 -04:00
committed by GitHub
parent af8d257aa2
commit fefb0797ee

View File

@@ -2,8 +2,11 @@
import importlib
import inspect
import os
from typing import Any
from huggingface_hub import snapshot_download
from requests import HTTPError
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
@@ -178,9 +181,18 @@ class GRPOStrategy:
"Reward function must accept at least two arguments: prompts: list and completions: list"
)
return reward_func
except ModuleNotFoundError:
except ModuleNotFoundError as exc:
# the user has passed a string (ideally indicating the path of a reward model)
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func_fqn
# check if it's a local dir path and not empty dir to a reward model
pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn):
LOG.info(pretrained_log_msg)
return reward_func_fqn
try:
snapshot_download(reward_func_fqn, repo_type="model")
LOG.info(pretrained_log_msg)
return reward_func_fqn
except HTTPError:
raise ValueError(
f"Reward function {reward_func_fqn} not found."
) from exc