From fefb0797ee2e6ded50394fd2048e62f19ce6b6c2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 21 Jul 2025 11:41:15 -0400 Subject: [PATCH] better handling for reward function checks for GRPO (#2933) [skip ci] * better handling for reward function checks for GRPO * consolidate msg copy --- src/axolotl/core/trainers/grpo/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 3e824f705..2c6eb8c6f 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -2,8 +2,11 @@ import importlib import inspect +import os from typing import Any +from huggingface_hub import snapshot_download +from requests import HTTPError from trl.trainer.grpo_trainer import RewardFunc from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig @@ -178,9 +181,18 @@ class GRPOStrategy: "Reward function must accept at least two arguments: prompts: list and completions: list" ) return reward_func - except ModuleNotFoundError: + except ModuleNotFoundError as exc: # the user has passed a string (ideally indicating the path of a reward model) - LOG.info( - f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." - ) - return reward_func_fqn + # check if it's a local dir path and not empty dir to a reward model + pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." + if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn): + LOG.info(pretrained_log_msg) + return reward_func_fqn + try: + snapshot_download(reward_func_fqn, repo_type="model") + LOG.info(pretrained_log_msg) + return reward_func_fqn + except HTTPError: + raise ValueError( + f"Reward function {reward_func_fqn} not found." + ) from exc