feat:openenv rollout_func (#3239) [skip ci]

* feat:openenv rollout_func

* chore lint

* docs

* add:docs processing_class

* tests

* lint
This commit is contained in:
VED
2025-11-07 19:21:40 +05:30
committed by GitHub
parent 80270a92fa
commit ed2e8cacd6
4 changed files with 168 additions and 0 deletions

View File

@@ -597,6 +597,118 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions
```bash
pip install openenv-core
```
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
```python
# math_env.py
import re
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Each prompt is rendered with the chat template, then the model is asked
    for up to five reasoning steps. Generation for a prompt stops early as
    soon as a step contains "FINAL ANSWER:".

    Args:
        model: The language model (must expose `generate` and `device`)
        processing_class: The tokenizer/processing_class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration, forwarded to
            `model.generate`

    Returns:
        List of completion strings, one per prompt, in input order
    """
    max_steps = 5  # Max 5 reasoning steps per prompt
    completions = []

    for prompt in prompts:
        # Apply chat template to prompt. `tokenize=False` is required so the
        # template comes back as text: the loop below concatenates it with
        # the response generated so far.
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate step-by-step solution
        full_response = ""
        for _ in range(max_steps):
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            # Skip the first input-length tokens so only the continuation
            # is decoded.
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True,
            )

            # Check if solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break

            full_response += step_text + "\n"

        completions.append(full_response)

    return completions
def math_reward(prompts, completions, answers=None, **kwargs):
    """
    Reward function that checks mathematical correctness.

    Args:
        prompts: Prompts the completions were generated for (unused)
        completions: List of completion strings
        answers: Reference answers aligned with `completions`. When omitted,
            the references are read from the `answer` keyword argument.
        **kwargs: Extra keyword arguments; only `answer` is read

    Returns:
        List of floats: 1.0 where the text following "FINAL ANSWER:" equals
        the reference answer, otherwise 0.0

    Raises:
        TypeError: If neither `answers` nor `answer` is provided
    """
    if answers is None:
        # The dataset transform in this example stores the reference under
        # the "answer" key, so accept that spelling as well.
        answers = kwargs.get("answer")
    if answers is None:
        raise TypeError(
            "math_reward() requires reference answers via 'answers' or 'answer'"
        )

    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract predicted answer (rest of the line after the marker)
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""

        # Compare with correct answer
        rewards.append(1.0 if predicted == str(correct_answer) else 0.0)

    return rewards
def math_transform(cfg, *args, **kwargs):
    """
    Transform dataset to GRPO format with answer field.

    Returns:
        A `(transform_fn, map_kwargs)` tuple. `transform_fn` turns one
        example into a single-turn chat `prompt` plus a string `answer`;
        `map_kwargs` asks for the original `question` column to be removed.
    """

    def transform_fn(example, processing_class=None):
        # GSM8K stores the worked solution followed by "#### <final answer>".
        # Keep only the final answer so it can be compared with the text that
        # math_reward extracts after "FINAL ANSWER:". Answers without the
        # marker pass through unchanged apart from whitespace trimming.
        final_answer = str(example["answer"]).split("####")[-1].strip()
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": final_answer,
        }

    return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
  beta: 0.001
  max_completion_length: 512
  num_generations: 4
  rollout_func: "math_env.math_solver_rollout" # Custom rollout function
  reward_funcs: ["math_env.math_reward"]
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: math_env.math_transform
```
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration
The function should return a list of completion strings.
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.

View File

@@ -126,6 +126,9 @@ class GRPOStrategy:
if trl.use_liger_loss is not None: if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod
@@ -201,3 +204,32 @@ class GRPOStrategy:
raise ValueError( raise ValueError(
f"Reward function {reward_func_fqn} not found." f"Reward function {reward_func_fqn} not found."
) from exc ) from exc
@classmethod
def get_rollout_func(cls, rollout_func_fqn: str):
    """
    Returns the rollout function from the given fully qualified name.

    Args:
        rollout_func_fqn (str): Fully qualified name of the rollout function
            (e.g. my_module.my_rollout_func)

    Returns:
        Callable rollout function

    Raises:
        ValueError: If the name has no module part, the module cannot be
            imported, the module has no such attribute, or the attribute
            is not callable.
    """
    # Split on the last dot: everything before it is the module path.
    module_path, _, func_name = rollout_func_fqn.rpartition(".")
    if not module_path or not func_name:
        raise ValueError(
            f"Rollout function {rollout_func_fqn} not found. "
            "Expected a fully qualified name such as my_module.my_rollout_func."
        )

    try:
        rollout_func_module = importlib.import_module(module_path)
    except ModuleNotFoundError as exc:
        raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

    try:
        rollout_func = getattr(rollout_func_module, func_name)
    except AttributeError as exc:
        # The module imported fine but does not define the requested name;
        # report it the same way as a missing module.
        raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

    if not callable(rollout_func):
        raise ValueError(
            f"Rollout function {rollout_func_fqn} must be callable"
        )

    return rollout_func

View File

@@ -173,3 +173,9 @@ class TRLConfig(BaseModel):
"description": "Enable sleep mode for vLLM to offload VRAM when idle" "description": "Enable sleep mode for vLLM to offload VRAM when idle"
}, },
) )
# Optional dotted path (e.g. `module_name.function_name`) to a custom
# rollout function; stays None when no custom rollout is configured.
rollout_func: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to custom rollout function. Must be importable from current dir."
},
)

View File

@@ -0,0 +1,18 @@
import os
import pytest
from axolotl.core.trainers.grpo import GRPOStrategy
def test_get_rollout_func_loads_successfully():
    """A resolvable dotted path yields the very callable it names."""
    loaded = GRPOStrategy.get_rollout_func("os.path.join")

    assert loaded == os.path.join
    assert callable(loaded)
def test_get_rollout_func_invalid_module_raises_error():
    """An unimportable module path surfaces as a descriptive ValueError."""
    expected_message = "Rollout function .* not found"

    with pytest.raises(ValueError, match=expected_message):
        GRPOStrategy.get_rollout_func("nonexistent_module.my_func")