better handling for reward function checks for GRPO (#2933) [skip ci]
* better handling for reward function checks for GRPO * consolidate msg copy
This commit is contained in:
@@ -2,8 +2,11 @@
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from requests import HTTPError
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
@@ -178,9 +181,18 @@ class GRPOStrategy:
|
||||
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||
)
|
||||
return reward_func
|
||||
except ModuleNotFoundError:
|
||||
except ModuleNotFoundError as exc:
|
||||
# the user has passed a string (ideally indicating the path of a reward model)
|
||||
LOG.info(
|
||||
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
)
|
||||
return reward_func_fqn
|
||||
# check if it's a local dir path and not empty dir to a reward model
|
||||
pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn):
|
||||
LOG.info(pretrained_log_msg)
|
||||
return reward_func_fqn
|
||||
try:
|
||||
snapshot_download(reward_func_fqn, repo_type="model")
|
||||
LOG.info(pretrained_log_msg)
|
||||
return reward_func_fqn
|
||||
except HTTPError:
|
||||
raise ValueError(
|
||||
f"Reward function {reward_func_fqn} not found."
|
||||
) from exc
|
||||
|
||||
Reference in New Issue
Block a user