feat:openenv rollout_func (#3239) [skip ci]

* feat:openenv rollout_func

* chore lint

* docs

* add:docs processing_class

* tests

* lint
This commit is contained in:
VED
2025-11-07 19:21:40 +05:30
committed by GitHub
parent 80270a92fa
commit ed2e8cacd6
4 changed files with 168 additions and 0 deletions

View File

@@ -126,6 +126,9 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs
@classmethod
@@ -201,3 +204,32 @@ class GRPOStrategy:
raise ValueError(
f"Reward function {reward_func_fqn} not found."
) from exc
@classmethod
def get_rollout_func(cls, rollout_func_fqn: str):
"""
Returns the rollout function from the given fully qualified name.
Args:
rollout_func_fqn (str): Fully qualified name of the rollout function
(e.g. my_module.my_rollout_func)
Returns:
Callable rollout function
"""
try:
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
rollout_func_module = importlib.import_module(
".".join(rollout_func_fqn.split(".")[:-1])
)
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
if not callable(rollout_func):
raise ValueError(
f"Rollout function {rollout_func_fqn} must be callable"
)
return rollout_func
except ModuleNotFoundError as exc:
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

View File

@@ -173,3 +173,9 @@ class TRLConfig(BaseModel):
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)
rollout_func: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to custom rollout function. Must be importable from current dir."
},
)