---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback. Supported methods include, but are not limited to:

- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in Axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The `type:` can be retrieved from `{method}.{function_name}`.
:::
### DPO

Example config:

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```

DPO supports the following types with the corresponding dataset formats:
#### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

#### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

#### chat_template.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
#### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    },
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```

#### user_defined.default

For custom behaviors, use the following config:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_chosen: "chosen"
      field_rejected: "rejected"
      prompt_format: "{prompt}"
      chosen_format: "{chosen}"
      rejected_format: "{rejected}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```
### IPO

As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

```yaml
rl: ipo
```
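
For example, a minimal IPO config can reuse the same preference dataset as the DPO example above unchanged (a sketch):

```yaml
rl: ipo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
```
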
### ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```

ORPO supports the following types with the corresponding dataset formats:

#### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if available, will be taken as the user message for single-turn instead of from the lists below

  // chosen/rejected should be identical up to the final assistant message, with an even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
### KTO

```yaml
rl: kto
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```
KTO supports the following types with the corresponding dataset formats:

#### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### user_defined.default

For custom behaviors, use the following config:

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```
### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
:::

In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we're using 4 GPUs: 2 for training and 2 for vLLM.

::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
```

Your `vLLM` instance will now spin up, and it's time to kick off training on the remaining two GPUs. In another terminal, execute:

```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```

::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why the example above uses `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::
#### Reward functions

GRPO uses custom reward functions and transformations, which must be available locally.

For example, to load OpenAI's GSM8K and use a random reward for completions:

```python
# rewards.py
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": label,
        }

    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
```

To see other examples of custom reward functions, please see the [TRL GRPO docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).

To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions

GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.

For example, to implement a simple math-solving environment with step-by-step verification:

```python
# math_env.py
import re


def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Args:
        model: The language model
        processing_class: The tokenizer/processing class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration

    Returns:
        List of completion strings
    """
    completions = []

    for prompt in prompts:
        # Apply chat template to the prompt
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate a step-by-step solution
        full_response = ""
        for step in range(5):  # Max 5 reasoning steps
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True,
            )

            # Check if the solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break
            full_response += step_text + "\n"

        completions.append(full_response)

    return completions


def math_reward(prompts, completions, answers, **kwargs):
    """Reward function that checks mathematical correctness"""
    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract the predicted answer
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""

        # Compare with the correct answer
        reward = 1.0 if predicted == str(correct_answer) else 0.0
        rewards.append(reward)

    return rewards


def math_transform(cfg, *args, **kwargs):
    """Transform dataset to GRPO format with an answer field"""

    def transform_fn(example, processing_class=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": str(example["answer"]),
        }

    return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 512
  num_generations: 4
  rollout_func: "math_env.math_solver_rollout" # Custom rollout function
  reward_funcs: ["math_env.math_reward"]
  reward_weights: [1.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: math_env.math_transform
```

The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable in your local directory. The function receives:

- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration

It should return a list of completion strings.
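
A minimal rollout satisfying this contract looks like the trivial sketch below, useful only for smoke-testing the wiring:

```python
def constant_rollout(model, processing_class, prompts, generation_config=None):
    # Return exactly one completion string per prompt, as the contract requires.
    return ["FINAL ANSWER: 42" for _ in prompts]
```
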
For more OpenEnv examples, see the [TRL OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).

#### GRPO with DAPO/Dr. GRPO loss

The DAPO paper and, subsequently, the Dr. GRPO paper proposed alternative loss functions for GRPO that remedy the implicit penalty on longer responses.

```yaml
trl:
  loss_type: dr_grpo
  # Normalizes loss based on max completion length (default: 256)
  max_completion_length:
```
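
Conceptually, the two loss types differ in how per-token losses are reduced. A minimal sketch of the idea (illustrative only, not TRL's actual implementation):

```python
# Per-token losses for one sampled completion (hypothetical values).
token_losses = [0.8, 0.5, 0.3, 0.2]
max_completion_length = 256

# GRPO: average over the sequence's own length, which down-weights every
# token of a long completion.
grpo_loss = sum(token_losses) / len(token_losses)

# Dr. GRPO: divide by a constant (the max completion length) instead,
# removing the length-dependent scaling.
dr_grpo_loss = sum(token_losses) / max_completion_length
```
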
For more information, see the [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.

::: {.callout-tip}
Use GDPO when training with multiple reward functions. With a single reward, GRPO and GDPO produce equivalent results.
:::

Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)

GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85

rl: gdpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:

```yaml
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum # GDPO behavior
  # or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO

| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?

When using multiple rewards with GRPO, different reward combinations can produce identical advantages:

```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```

GDPO normalizes each reward independently, preserving their relative differences.
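
The entanglement also shows up when rewards live on different scales. A minimal numeric sketch of the two aggregation orders (illustrative only; TRL's actual implementation differs in details):

```python
import statistics

# One group of 4 completions scored by two reward functions on very
# different scales: format is 0/1, correctness is 0/10 (hypothetical values).
format_r = [1.0, 0.0, 1.0, 0.0]
correct_r = [10.0, 10.0, 0.0, 0.0]

def normalize(xs):
    """GRPO-style group normalization: (x - mean) / std."""
    mean, std = statistics.mean(xs), statistics.pstdev(xs)
    return [(x - mean) / std for x in xs]

# sum_then_normalize (GRPO default): the large-scale correctness reward
# dominates the summed reward, and the format signal nearly vanishes.
grpo_adv = normalize([f + c for f, c in zip(format_r, correct_r)])

# normalize_then_sum (GDPO): each reward is put on a comparable scale
# before summing, so both signals survive in the advantage.
gdpo_adv = [f + c for f, c in zip(normalize(format_r), normalize(correct_r))]

print(grpo_adv)  # ≈ [1.09, 0.90, -0.90, -1.09]: driven almost entirely by correctness
print(gdpo_adv)  # [2.0, 0.0, 0.0, -2.0]: format and correctness both preserved
```
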
#### Reward Functions

GDPO uses the same reward function format as GRPO:

```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(c) > 10 else 0.0 for c in completions]


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Your scoring logic here; e.g. a simple containment check
        score = 1.0 if str(answer) in completion else 0.0
        rewards.append(score)
    return rewards
```
#### Sequence Parallelism

GDPO supports sequence parallelism for long-context training:

```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO

SimPO uses TRL's [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) with an alternative loss function.

```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```
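
For reference, the SimPO paper's objective drops the reference model and uses a length-normalized implicit reward with a target margin (here `rl_beta` plays the role of $\beta$ and `simpo_gamma` the role of $\gamma$):

$$
\mathcal{L}_{\text{SimPO}} = -\log \sigma\left(\frac{\beta}{|y_w|}\log \pi_\theta(y_w \mid x) - \frac{\beta}{|y_l|}\log \pi_\theta(y_l \mid x) - \gamma\right)
$$
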
This method uses the same dataset format as [DPO](#dpo).

### Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```
### TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded, and reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```
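
Conceptually, the reference forward pass reuses the same weights with the adapters switched off. A rough sketch built on PEFT's `disable_adapter` context manager (not TRL's exact code):

```python
import torch

def reference_logits(peft_model, input_ids, attention_mask):
    # With the LoRA adapters disabled, the model behaves as the frozen base
    # model, so no second copy of the weights is needed in memory.
    with torch.no_grad(), peft_model.disable_adapter():
        return peft_model(input_ids=input_ids, attention_mask=attention_mask).logits
```
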