better handling for reward function checks for GRPO (#2933) [skip ci]
* better handling for reward function checks for GRPO * consolidate msg copy
This commit is contained in:
@@ -2,8 +2,11 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from requests import HTTPError
|
||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||||
@@ -178,9 +181,18 @@ class GRPOStrategy:
|
|||||||
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||||
)
|
)
|
||||||
return reward_func
|
return reward_func
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError as exc:
|
||||||
# the user has passed a string (ideally indicating the path of a reward model)
|
# the user has passed a string (ideally indicating the path of a reward model)
|
||||||
LOG.info(
|
# check if it's a local dir path and not empty dir to a reward model
|
||||||
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||||
)
|
if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn):
|
||||||
return reward_func_fqn
|
LOG.info(pretrained_log_msg)
|
||||||
|
return reward_func_fqn
|
||||||
|
try:
|
||||||
|
snapshot_download(reward_func_fqn, repo_type="model")
|
||||||
|
LOG.info(pretrained_log_msg)
|
||||||
|
return reward_func_fqn
|
||||||
|
except HTTPError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Reward function {reward_func_fqn} not found."
|
||||||
|
) from exc
|
||||||
|
|||||||
Reference in New Issue
Block a user