feat:openenv rollout_func (#3239) [skip ci]
* feat:openenv rollout_func * chore lint * docs * add:docs processing_class * tests * lint
This commit is contained in:
112
docs/rlhf.qmd
112
docs/rlhf.qmd
@@ -597,6 +597,118 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
|
||||
|
||||
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
|
||||
|
||||
#### OpenEnv Rollout Functions
|
||||
```bash
|
||||
pip insatll openenv-core```
|
||||
|
||||
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
|
||||
|
||||
For example, to implement a simple math-solving environment with step-by-step verification:
|
||||
|
||||
```python
|
||||
# math_env.py
|
||||
import re
|
||||
|
||||
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
|
||||
"""
|
||||
Custom rollout function that generates step-by-step math solutions.
|
||||
|
||||
Args:
|
||||
model: The language model
|
||||
processing_class: The tokenizer/processing_class
|
||||
prompts: List of prompt dicts (with 'messages' key for chat format)
|
||||
generation_config: Optional generation configuration
|
||||
|
||||
Returns:
|
||||
List of completion strings
|
||||
"""
|
||||
completions = []
|
||||
|
||||
for prompt in prompts:
|
||||
# Apply chat template to prompt
|
||||
messages = prompt.get("messages", [])
|
||||
formatted_prompt = processing_class.apply_chat_template(
|
||||
messages, processing_class=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Generate step-by-step solution
|
||||
full_response = ""
|
||||
for step in range(5): # Max 5 reasoning steps
|
||||
current_input = formatted_prompt + full_response + "\nNext step:"
|
||||
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=100,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
step_text = processing_class.decode(
|
||||
outputs[0][inputs.input_ids.shape[1]:],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Check if solution is complete
|
||||
if "FINAL ANSWER:" in step_text:
|
||||
full_response += step_text
|
||||
break
|
||||
full_response += step_text + "\n"
|
||||
|
||||
completions.append(full_response)
|
||||
|
||||
return completions
|
||||
|
||||
def math_reward(prompts, completions, answers, **kwargs):
|
||||
"""Reward function that checks mathematical correctness"""
|
||||
rewards = []
|
||||
for completion, correct_answer in zip(completions, answers):
|
||||
# Extract predicted answer
|
||||
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
|
||||
predicted = match.group(1).strip() if match else ""
|
||||
|
||||
# Compare with correct answer
|
||||
reward = 1.0 if predicted == str(correct_answer) else 0.0
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
def math_transform(cfg, *args, **kwargs):
|
||||
"""Transform dataset to GRPO format with answer field"""
|
||||
def transform_fn(example, processing_class=None):
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]}],
|
||||
"answer": str(example["answer"]),
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
```
|
||||
|
||||
```yaml
|
||||
rl: grpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 512
|
||||
num_generations: 4
|
||||
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
|
||||
reward_funcs: ["math_env.math_reward"]
|
||||
reward_weights: [1.0]
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: math_env.math_transform
|
||||
```
|
||||
|
||||
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
|
||||
|
||||
- `model`: The language model
|
||||
- `processing_class`: The tokenizer/processing class
|
||||
- `prompts`: List of prompt dictionaries
|
||||
- `generation_config` (optional): Generation configuration
|
||||
|
||||
And should return a list of completion strings.
|
||||
|
||||
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
|
||||
|
||||
#### GRPO with DAPO/Dr. GRPO loss
|
||||
|
||||
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
|
||||
|
||||
Reference in New Issue
Block a user