diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 0d3e7b7d7..c82fe69f2 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -135,7 +135,9 @@ class GRPOStrategy: try: # use importlib to dynamically load the reward function from the module reward_func_module_name = reward_func_fqn.split(".")[-1] - reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2]) + reward_func_module = importlib.import_module( + ".".join(reward_func_fqn.split(".")[:-1]) + ) reward_func = getattr(reward_func_module, reward_func_module_name) if not len(inspect.signature(reward_func).parameters) >= 2: raise ValueError(