feat:openenv rollout_func (#3239) [skip ci]
* feat:openenv rollout_func * chore lint * docs * add:docs processing_class * tests * lint
This commit is contained in:
18
tests/utils/test_grpo_rw_fnc.py
Normal file
18
tests/utils/test_grpo_rw_fnc.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
|
||||
def test_get_rollout_func_loads_successfully():
|
||||
"""Test that a valid rollout function can be loaded"""
|
||||
rollout_func = GRPOStrategy.get_rollout_func("os.path.join")
|
||||
assert callable(rollout_func)
|
||||
assert rollout_func == os.path.join
|
||||
|
||||
|
||||
def test_get_rollout_func_invalid_module_raises_error():
|
||||
"""Test that invalid module path raises clear ValueError"""
|
||||
with pytest.raises(ValueError, match="Rollout function .* not found"):
|
||||
GRPOStrategy.get_rollout_func("nonexistent_module.my_func")
|
||||
Reference in New Issue
Block a user