feat:openenv rollout_func (#3239) [skip ci]
* feat:openenv rollout_func * chore lint * docs * add:docs processing_class * tests * lint
This commit is contained in:
112
docs/rlhf.qmd
112
docs/rlhf.qmd
@@ -597,6 +597,118 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
|
|||||||
|
|
||||||
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
|
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
|
||||||
|
|
||||||
|
#### OpenEnv Rollout Functions
|
||||||
|
```bash
|
||||||
|
pip insatll openenv-core```
|
||||||
|
|
||||||
|
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
|
||||||
|
|
||||||
|
For example, to implement a simple math-solving environment with step-by-step verification:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# math_env.py
|
||||||
|
import re
|
||||||
|
|
||||||
|
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
|
||||||
|
"""
|
||||||
|
Custom rollout function that generates step-by-step math solutions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The language model
|
||||||
|
processing_class: The tokenizer/processing_class
|
||||||
|
prompts: List of prompt dicts (with 'messages' key for chat format)
|
||||||
|
generation_config: Optional generation configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of completion strings
|
||||||
|
"""
|
||||||
|
completions = []
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
# Apply chat template to prompt
|
||||||
|
messages = prompt.get("messages", [])
|
||||||
|
formatted_prompt = processing_class.apply_chat_template(
|
||||||
|
messages, processing_class=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate step-by-step solution
|
||||||
|
full_response = ""
|
||||||
|
for step in range(5): # Max 5 reasoning steps
|
||||||
|
current_input = formatted_prompt + full_response + "\nNext step:"
|
||||||
|
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=100,
|
||||||
|
generation_config=generation_config,
|
||||||
|
)
|
||||||
|
step_text = processing_class.decode(
|
||||||
|
outputs[0][inputs.input_ids.shape[1]:],
|
||||||
|
skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if solution is complete
|
||||||
|
if "FINAL ANSWER:" in step_text:
|
||||||
|
full_response += step_text
|
||||||
|
break
|
||||||
|
full_response += step_text + "\n"
|
||||||
|
|
||||||
|
completions.append(full_response)
|
||||||
|
|
||||||
|
return completions
|
||||||
|
|
||||||
|
def math_reward(prompts, completions, answers, **kwargs):
|
||||||
|
"""Reward function that checks mathematical correctness"""
|
||||||
|
rewards = []
|
||||||
|
for completion, correct_answer in zip(completions, answers):
|
||||||
|
# Extract predicted answer
|
||||||
|
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
|
||||||
|
predicted = match.group(1).strip() if match else ""
|
||||||
|
|
||||||
|
# Compare with correct answer
|
||||||
|
reward = 1.0 if predicted == str(correct_answer) else 0.0
|
||||||
|
rewards.append(reward)
|
||||||
|
|
||||||
|
return rewards
|
||||||
|
|
||||||
|
def math_transform(cfg, *args, **kwargs):
|
||||||
|
"""Transform dataset to GRPO format with answer field"""
|
||||||
|
def transform_fn(example, processing_class=None):
|
||||||
|
return {
|
||||||
|
"prompt": [{"role": "user", "content": example["question"]}],
|
||||||
|
"answer": str(example["answer"]),
|
||||||
|
}
|
||||||
|
return transform_fn, {"remove_columns": ["question"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
rl: grpo
|
||||||
|
|
||||||
|
trl:
|
||||||
|
beta: 0.001
|
||||||
|
max_completion_length: 512
|
||||||
|
num_generations: 4
|
||||||
|
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
|
||||||
|
reward_funcs: ["math_env.math_reward"]
|
||||||
|
reward_weights: [1.0]
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: openai/gsm8k
|
||||||
|
name: main
|
||||||
|
type: math_env.math_transform
|
||||||
|
```
|
||||||
|
|
||||||
|
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
|
||||||
|
|
||||||
|
- `model`: The language model
|
||||||
|
- `processing_class`: The tokenizer/processing class
|
||||||
|
- `prompts`: List of prompt dictionaries
|
||||||
|
- `generation_config` (optional): Generation configuration
|
||||||
|
|
||||||
|
And should return a list of completion strings.
|
||||||
|
|
||||||
|
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
|
||||||
|
|
||||||
#### GRPO with DAPO/Dr. GRPO loss
|
#### GRPO with DAPO/Dr. GRPO loss
|
||||||
|
|
||||||
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
|
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
|
||||||
|
|||||||
@@ -126,6 +126,9 @@ class GRPOStrategy:
|
|||||||
if trl.use_liger_loss is not None:
|
if trl.use_liger_loss is not None:
|
||||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||||
|
|
||||||
|
if trl.rollout_func:
|
||||||
|
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -201,3 +204,32 @@ class GRPOStrategy:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Reward function {reward_func_fqn} not found."
|
f"Reward function {reward_func_fqn} not found."
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_rollout_func(cls, rollout_func_fqn: str):
|
||||||
|
"""
|
||||||
|
Returns the rollout function from the given fully qualified name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rollout_func_fqn (str): Fully qualified name of the rollout function
|
||||||
|
(e.g. my_module.my_rollout_func)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable rollout function
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
|
||||||
|
rollout_func_module = importlib.import_module(
|
||||||
|
".".join(rollout_func_fqn.split(".")[:-1])
|
||||||
|
)
|
||||||
|
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
|
||||||
|
|
||||||
|
if not callable(rollout_func):
|
||||||
|
raise ValueError(
|
||||||
|
f"Rollout function {rollout_func_fqn} must be callable"
|
||||||
|
)
|
||||||
|
|
||||||
|
return rollout_func
|
||||||
|
|
||||||
|
except ModuleNotFoundError as exc:
|
||||||
|
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc
|
||||||
|
|||||||
@@ -173,3 +173,9 @@ class TRLConfig(BaseModel):
|
|||||||
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
|
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
rollout_func: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Path to custom rollout function. Must be importable from current dir."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
18
tests/utils/test_grpo_rw_fnc.py
Normal file
18
tests/utils/test_grpo_rw_fnc.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rollout_func_loads_successfully():
|
||||||
|
"""Test that a valid rollout function can be loaded"""
|
||||||
|
rollout_func = GRPOStrategy.get_rollout_func("os.path.join")
|
||||||
|
assert callable(rollout_func)
|
||||||
|
assert rollout_func == os.path.join
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rollout_func_invalid_module_raises_error():
|
||||||
|
"""Test that invalid module path raises clear ValueError"""
|
||||||
|
with pytest.raises(ValueError, match="Rollout function .* not found"):
|
||||||
|
GRPOStrategy.get_rollout_func("nonexistent_module.my_func")
|
||||||
Reference in New Issue
Block a user