diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 3e824f705..2c6eb8c6f 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -2,8 +2,11 @@ import importlib import inspect +import os from typing import Any +from huggingface_hub import snapshot_download +from requests import HTTPError from trl.trainer.grpo_trainer import RewardFunc from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig @@ -178,9 +181,18 @@ class GRPOStrategy: "Reward function must accept at least two arguments: prompts: list and completions: list" ) return reward_func - except ModuleNotFoundError: + except ModuleNotFoundError as exc: # the user has passed a string (ideally indicating the path of a reward model) - LOG.info( - f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." - ) - return reward_func_fqn + # check if it's a local dir path and not empty dir to a reward model + pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." + if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn): + LOG.info(pretrained_log_msg) + return reward_func_fqn + try: + snapshot_download(reward_func_fqn, repo_type="model") + LOG.info(pretrained_log_msg) + return reward_func_fqn + except HTTPError: + raise ValueError( + f"Reward function {reward_func_fqn} not found." + ) from exc