feat:openenv rollout_func (#3239) [skip ci]

* feat:openenv rollout_func

* chore lint

* docs

* add:docs processing_class

* tests

* lint
This commit is contained in:
VED
2025-11-07 19:21:40 +05:30
committed by GitHub
parent 80270a92fa
commit ed2e8cacd6
4 changed files with 168 additions and 0 deletions

View File

@@ -597,6 +597,118 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions
```bash
pip insatll openenv-core```
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
```python
# math_env.py
import re
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
"""
Custom rollout function that generates step-by-step math solutions.
Args:
model: The language model
processing_class: The tokenizer/processing_class
prompts: List of prompt dicts (with 'messages' key for chat format)
generation_config: Optional generation configuration
Returns:
List of completion strings
"""
completions = []
for prompt in prompts:
# Apply chat template to prompt
messages = prompt.get("messages", [])
formatted_prompt = processing_class.apply_chat_template(
messages, processing_class=False, add_generation_prompt=True
)
# Generate step-by-step solution
full_response = ""
for step in range(5): # Max 5 reasoning steps
current_input = formatted_prompt + full_response + "\nNext step:"
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=100,
generation_config=generation_config,
)
step_text = processing_class.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Check if solution is complete
if "FINAL ANSWER:" in step_text:
full_response += step_text
break
full_response += step_text + "\n"
completions.append(full_response)
return completions
def math_reward(prompts, completions, answers, **kwargs):
"""Reward function that checks mathematical correctness"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract predicted answer
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
predicted = match.group(1).strip() if match else ""
# Compare with correct answer
reward = 1.0 if predicted == str(correct_answer) else 0.0
rewards.append(reward)
return rewards
def math_transform(cfg, *args, **kwargs):
"""Transform dataset to GRPO format with answer field"""
def transform_fn(example, processing_class=None):
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": str(example["answer"]),
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 512
num_generations: 4
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
reward_funcs: ["math_env.math_reward"]
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
type: math_env.math_transform
```
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration
And should return a list of completion strings.
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.

View File

@@ -126,6 +126,9 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs
@classmethod
@@ -201,3 +204,32 @@ class GRPOStrategy:
raise ValueError(
f"Reward function {reward_func_fqn} not found."
) from exc
@classmethod
def get_rollout_func(cls, rollout_func_fqn: str):
"""
Returns the rollout function from the given fully qualified name.
Args:
rollout_func_fqn (str): Fully qualified name of the rollout function
(e.g. my_module.my_rollout_func)
Returns:
Callable rollout function
"""
try:
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
rollout_func_module = importlib.import_module(
".".join(rollout_func_fqn.split(".")[:-1])
)
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
if not callable(rollout_func):
raise ValueError(
f"Rollout function {rollout_func_fqn} must be callable"
)
return rollout_func
except ModuleNotFoundError as exc:
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

View File

@@ -173,3 +173,9 @@ class TRLConfig(BaseModel):
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)
rollout_func: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to custom rollout function. Must be importable from current dir."
},
)

View File

@@ -0,0 +1,18 @@
import os
import pytest
from axolotl.core.trainers.grpo import GRPOStrategy
def test_get_rollout_func_loads_successfully():
"""Test that a valid rollout function can be loaded"""
rollout_func = GRPOStrategy.get_rollout_func("os.path.join")
assert callable(rollout_func)
assert rollout_func == os.path.join
def test_get_rollout_func_invalid_module_raises_error():
"""Test that invalid module path raises clear ValueError"""
with pytest.raises(ValueError, match="Rollout function .* not found"):
GRPOStrategy.get_rollout_func("nonexistent_module.my_func")