---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

Reinforcement Learning from Human Feedback (RLHF) is a method whereby a language model is optimized from data using human feedback. Supported methods include, but are not limited to:

- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Simple Preference Optimization (SimPO)](#simpo)

## RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by looking in `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The value for `type:` takes the form `{module}.{function_name}`, where `{module}` is a file in that directory; for example, `chatml.intel` refers to the `intel` function in `chatml.py`.
:::

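As a rough illustration of this naming scheme, the sketch below resolves a dotted `type:` string such as `chatml.intel` into a function with `importlib`. The helper name `resolve_type` and its `base_package` argument are hypothetical, and axolotl's own loader may behave differently (for instance for a `type:` with no dot); this sketch only handles the two-part `{module}.{function_name}` form.

```python
import importlib
import os.path  # only used by the standard-library demo at the bottom
from typing import Any, Callable


def resolve_type(type_str: str, base_package: str) -> Callable[..., Any]:
    """Resolve a dotted `type:` value such as "chatml.intel" to a callable.

    Hypothetical helper: the text before the last dot names a module inside
    `base_package`, and the text after it names an attribute of that module.
    """
    module_name, _, fn_name = type_str.rpartition(".")
    if not module_name or not fn_name:
        raise ValueError(f"expected '<module>.<function>', got {type_str!r}")
    module = importlib.import_module(f"{base_package}.{module_name}")
    return getattr(module, fn_name)


# With axolotl installed, this would presumably look up `intel` in the DPO
# strategies package (assumed layout):
#   resolve_type("chatml.intel", "axolotl.prompt_strategies.dpo")

# Standard-library demo that runs anywhere:
join = resolve_type("path.join", "os")
assert join is os.path.join
```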
### DPO

Example config:

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```

DPO supports the following types, each expecting the dataset format shown:

#### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

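To make the mapping concrete, here is a minimal sketch of the kind of transform a `chatml.intel`-style strategy performs: it folds the optional `system` field and the `question` into a ChatML prompt and closes each response with the ChatML end-of-turn token. The function name and exact template are assumptions for illustration; the real implementation lives in the corresponding module under `src/axolotl/prompt_strategies/dpo/`.

```python
from typing import Dict


def intel_to_chatml(sample: Dict[str, str]) -> Dict[str, str]:
    """Hypothetical sketch: turn a chatml.intel-style record into prompt/chosen/rejected strings."""
    prompt = ""
    # The system turn is only emitted when a non-empty system message is present
    if sample.get("system"):
        prompt += f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{sample['question']}<|im_end|>\n"
    # Leave the assistant turn open so the chosen/rejected responses continue from here
    prompt += "<|im_start|>assistant\n"
    return {
        "prompt": prompt,
        # Each response is closed with the ChatML end-of-turn token
        "chosen": f"{sample['chosen']}<|im_end|>",
        "rejected": f"{sample['rejected']}<|im_end|>",
    }


record = {"question": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
print(intel_to_chatml(record)["prompt"])
```

The same prompt prefix is shared by both responses, which is what lets a preference loss compare them under identical context.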
#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

#### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

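Because `zephyr.nectar` records carry ranked answers rather than a single chosen/rejected pair, some selection rule has to reduce them to one pair. The sketch below assumes rank 1 is best and pairs the best-ranked answer against the worst-ranked one; that rule and the helper name are illustrative assumptions, not necessarily what `zephyr.nectar` itself does.

```python
from typing import Any, Dict, List


def nectar_to_pair(sample: Dict[str, Any]) -> Dict[str, str]:
    """Hypothetical sketch: reduce ranked answers to one preference pair.

    Assumes a lower `rank` is better (rank 1 is the best answer).
    """
    answers: List[Dict[str, Any]] = sorted(sample["answers"], key=lambda a: a["rank"])
    if len(answers) < 2:
        raise ValueError("need at least two ranked answers to form a pair")
    return {
        "prompt": sample["prompt"],
        "chosen": answers[0]["answer"],     # best-ranked answer
        "rejected": answers[-1]["answer"],  # worst-ranked answer
    }


example = {
    "prompt": "Say hi.",
    "answers": [
        {"answer": "Hello!", "rank": 2},
        {"answer": "Hi there!", "rank": 1},
        {"answer": "No.", "rank": 3},
    ],
}
pair = nectar_to_pair(example)  # chosen: "Hi there!", rejected: "No."
```

Pairing best against worst gives the clearest contrast; pairing adjacent ranks instead would give harder, closer pairs. Which suits your data is a judgment call.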
#### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    },
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```

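In the config above, `message_property_mappings` names the dataset keys that hold each message's role and content, and `roles` lists, for each canonical role, the values that may appear in the dataset. A minimal sketch of how such settings could normalize one message, assuming each alias maps to exactly one canonical role; the helper names and the `from`/`value`/`human`/`gpt` example dataset are hypothetical.

```python
from typing import Dict, List


def build_role_lookup(roles: Dict[str, List[str]]) -> Dict[str, str]:
    """Invert {"user": ["user", "human"], ...} into {"human": "user", ...}."""
    lookup: Dict[str, str] = {}
    for canonical, aliases in roles.items():
        for alias in aliases:
            if alias in lookup and lookup[alias] != canonical:
                raise ValueError(f"alias {alias!r} maps to more than one role")
            lookup[alias] = canonical
    return lookup


def normalize_message(
    message: Dict[str, str],
    property_mappings: Dict[str, str],
    role_lookup: Dict[str, str],
) -> Dict[str, str]:
    """Read role/content from the dataset keys named in `property_mappings`."""
    raw_role = message[property_mappings["role"]]
    return {
        "role": role_lookup[raw_role],
        "content": message[property_mappings["content"]],
    }


# Hypothetical dataset that stores messages as {"from": "human", "value": "..."}
roles = {"user": ["user", "human"], "assistant": ["assistant", "gpt"], "system": ["system"]}
mappings = {"role": "from", "content": "value"}
lookup = build_role_lookup(roles)
print(normalize_message({"from": "human", "value": "Hello"}, mappings, lookup))
# {'role': 'user', 'content': 'Hello'}
```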
#### user_defined.default

For custom behaviors, pass a mapping to `type:` that names the dataset fields and the format strings to apply:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_chosen: "chosen"
      field_rejected: "rejected"
      prompt_format: "{prompt}"
      chosen_format: "{chosen}"
      rejected_format: "{rejected}"
```

The input is a simple JSON object whose field names are set by the config above:

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

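One way to picture how the `field_*` keys and the `*_format` templates combine is with Python's `str.format`: each field is read from the record under its configured name and substituted into its template. This sketch assumes `str.format`-style placeholders named `{prompt}`, `{system}`, `{chosen}` and `{rejected}`; the function and the `CONFIG` dict (whose prompt template adds `{system}`) are hypothetical.

```python
from typing import Dict

# Hypothetical config mirroring the YAML above; the prompt template here also
# uses {system} to show how the optional system field could be included.
CONFIG = {
    "field_prompt": "prompt",
    "field_system": "system",
    "field_chosen": "chosen",
    "field_rejected": "rejected",
    "prompt_format": "{system}\n{prompt}",
    "chosen_format": "{chosen}",
    "rejected_format": "{rejected}",
}


def apply_user_defined(sample: Dict[str, str], cfg: Dict[str, str]) -> Dict[str, str]:
    """Sketch: read fields by their configured names and fill the templates."""
    values = {
        "prompt": sample[cfg["field_prompt"]],
        # The system field is optional, so fall back to an empty string
        "system": sample.get(cfg["field_system"], ""),
    }
    return {
        # str.format ignores keyword arguments the template does not use
        "prompt": cfg["prompt_format"].format(**values),
        "chosen": cfg["chosen_format"].format(chosen=sample[cfg["field_chosen"]]),
        "rejected": cfg["rejected_format"].format(rejected=sample[cfg["field_rejected"]]),
    }


row = {"system": "Be brief.", "prompt": "Name a color.", "chosen": "Blue.", "rejected": "I like many."}
print(apply_user_defined(row, CONFIG)["prompt"])
```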
### IPO

Because IPO is DPO with a different loss function, all dataset formats supported for [DPO](#dpo) are also supported for IPO.

```yaml
rl: ipo
```

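To show what "a different loss function" means here, the sketch below computes both losses for a single preference pair from summed log-probabilities under the policy and the reference model, following the published DPO and IPO objectives with `beta` as the regularization strength. It is a scalar illustration, not TRL's batched implementation.

```python
import math


def preference_margin(pi_chosen: float, pi_rejected: float,
                      ref_chosen: float, ref_rejected: float) -> float:
    """(log pi(y_w|x) - log ref(y_w|x)) - (log pi(y_l|x) - log ref(y_l|x))."""
    return (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)


def dpo_loss(margin: float, beta: float = 0.1) -> float:
    # -log(sigmoid(beta * margin)) == log(1 + exp(-beta * margin))
    return math.log1p(math.exp(-beta * margin))


def ipo_loss(margin: float, beta: float = 0.1) -> float:
    # Squared distance of the margin from the target 1 / (2 * beta)
    return (margin - 1.0 / (2.0 * beta)) ** 2


m = preference_margin(pi_chosen=-12.0, pi_rejected=-15.0, ref_chosen=-13.0, ref_rejected=-14.0)
print(m, dpo_loss(m), ipo_loss(m))
```

A positive margin means the policy favors the chosen response more than the reference does. DPO keeps rewarding ever-larger margins, whereas IPO pulls the margin toward `1 / (2 * beta)`, which is the regularizing effect that distinguishes it.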
### ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```

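ORPO adds an odds-ratio penalty to the ordinary supervised loss on the chosen response, and needs no reference model. The sketch below computes that objective for one pair from average per-token log-probabilities, following the paper's definition odds(y|x) = P(y|x) / (1 - P(y|x)). Treating `orpo_alpha` as the weight on the odds-ratio term is our reading of the config above and should be checked against the trainer.

```python
import math


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) for p = exp(avg_logp); requires avg_logp < 0."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_loss(avg_logp_chosen: float, avg_logp_rejected: float, alpha: float = 0.1) -> float:
    """SFT loss on the chosen response plus alpha times the odds-ratio penalty."""
    sft_loss = -avg_logp_chosen  # mean per-token negative log-likelihood of chosen
    ratio = log_odds(avg_logp_chosen) - log_odds(avg_logp_rejected)
    # -log(sigmoid(ratio)) is small when the chosen response has higher odds
    odds_ratio_penalty = math.log1p(math.exp(-ratio))
    return sft_loss + alpha * odds_ratio_penalty


print(orpo_loss(avg_logp_chosen=-0.5, avg_logp_rejected=-1.5))
```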
ORPO supports the following types, each expecting the dataset format shown:

#### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if present, used as the user message for a single-turn example instead of the turns below

  // chosen/rejected must be identical except for the final message, and contain an even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

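The comment in the sample above describes two constraints on `chosen` and `rejected`. A hypothetical pre-flight check you could run over your own data before training (not part of axolotl) might look like this; it assumes turns start with `user`, as in the sample.

```python
from typing import Dict, List

Message = Dict[str, str]


def check_orpo_pair(chosen: List[Message], rejected: List[Message]) -> None:
    """Raise ValueError if a chosen/rejected pair breaks the documented rules."""
    for name, turns in (("chosen", chosen), ("rejected", rejected)):
        if not turns or len(turns) % 2 != 0:
            raise ValueError(f"{name} must contain an even, non-zero number of turns")
        for i, turn in enumerate(turns):
            expected = "user" if i % 2 == 0 else "assistant"
            if turn["role"] != expected:
                raise ValueError(f"{name}[{i}] should be {expected!r}, got {turn['role']!r}")
    # Everything except the final assistant message must be identical
    # (this also rejects pairs with different numbers of turns)
    if chosen[:-1] != rejected[:-1]:
        raise ValueError("chosen and rejected must match up to the last message")


check_orpo_pair(
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Go away."}],
)
```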
### KTO

```yaml
rl: kto
rl_beta: 0.1  # default
kto_desirable_weight: 1.0  # default
kto_undesirable_weight: 1.0  # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```

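Unlike DPO, KTO scores each completion on its own as desirable or undesirable, relative to a reference point. The sketch below follows the per-example KTO loss as it is commonly implemented (for example in TRL), assuming `beta` corresponds to `rl_beta` and the two weights to `kto_desirable_weight` and `kto_undesirable_weight`. The reference point `kl` is passed in as a number here; in practice it is estimated from the batch, which this sketch does not model.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def kto_example_loss(
    policy_logp: float,
    ref_logp: float,
    desirable: bool,
    kl: float = 0.0,
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
) -> float:
    """Per-example KTO loss for one completion labeled desirable or undesirable."""
    log_ratio = policy_logp - ref_logp  # implicit reward, up to the factor beta
    kl = max(kl, 0.0)  # the reference point is clamped to be non-negative
    if desirable:
        return desirable_weight * (1.0 - sigmoid(beta * (log_ratio - kl)))
    return undesirable_weight * (1.0 - sigmoid(beta * (kl - log_ratio)))


print(kto_example_loss(policy_logp=-8.0, ref_logp=-10.0, desirable=True))
```

Raising `kto_desirable_weight` or `kto_undesirable_weight` rescales the corresponding term, which is the knob the KTO paper suggests for balancing datasets with unequal numbers of desirable and undesirable examples.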
KTO supports the following types, each expecting the dataset format shown:

#### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### user_defined.default

For custom behaviors, pass a mapping to `type:` that names the dataset fields and the format strings to apply:

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"
```

The input is a simple JSON object whose field names are set by the config above:

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```

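If your data is paired (as for DPO), a common way to obtain KTO-style examples is to split each pair into two unpaired rows: the chosen response labeled desirable and the rejected response labeled undesirable. A minimal sketch, assuming boolean labels (`True` = desirable) as in TRL's unpaired preference format; check what your chosen `type` expects before relying on it. The `unpair` helper is hypothetical.

```python
import json
from typing import Dict, Iterable, Iterator


def unpair(rows: Iterable[Dict[str, str]]) -> Iterator[Dict[str, object]]:
    """Split prompt/chosen/rejected rows into prompt/completion/label rows."""
    for row in rows:
        base: Dict[str, object] = {"prompt": row["prompt"]}
        if row.get("system"):
            base["system"] = row["system"]
        yield {**base, "completion": row["chosen"], "label": True}
        yield {**base, "completion": row["rejected"], "label": False}


pairs = [{"prompt": "2 + 2 =", "chosen": "4", "rejected": "5"}]
for example in unpair(pairs):
    print(json.dumps(example))
# {"prompt": "2 + 2 =", "completion": "4", "label": true}
# {"prompt": "2 + 2 =", "completion": "5", "label": false}
```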
### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
:::

In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we use 4 GPUs: 2 for training and 2 for vLLM.

::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

Save this config as `grpo.yaml`, then start the vLLM server on the last two GPUs:

```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
```

The `vLLM` server will now start up. In another terminal, launch training on the remaining two GPUs:

```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```

::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::

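The rule in the note can be written down as a small helper that, given the number of GPUs and how many the vLLM server needs (here `tensor_parallel_size: 2`), returns the two `CUDA_VISIBLE_DEVICES` values: the last N GPUs for vLLM and the rest for training. The helper is hypothetical and only builds the strings used in the commands above.

```python
from typing import Tuple


def split_gpus(num_gpus: int, vllm_gpus: int) -> Tuple[str, str]:
    """Return (training_devices, vllm_devices) as CUDA_VISIBLE_DEVICES strings."""
    if not 0 < vllm_gpus < num_gpus:
        raise ValueError("need at least one GPU for vLLM and one for training")
    ids = [str(i) for i in range(num_gpus)]
    # vLLM takes the last N GPUs; training uses the ones before them
    train, vllm = ids[:-vllm_gpus], ids[-vllm_gpus:]
    return ",".join(train), ",".join(vllm)


train_devices, vllm_devices = split_gpus(num_gpus=4, vllm_gpus=2)
print(f"CUDA_VISIBLE_DEVICES={vllm_devices} axolotl vllm-serve grpo.yaml")
print(f"CUDA_VISIBLE_DEVICES={train_devices} axolotl train grpo.yaml "
      f"--num-processes {len(train_devices.split(','))}")
```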
#### Reward functions

GRPO uses custom reward functions and dataset transformations, which you define in a local Python file and reference by name in the config.

For example, to load OpenAI's GSM8K and use a random reward for completions:

```python
# rewards.py
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    # Assign each completion a random reward in [0, 1]
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        # GSM8K answers end with "#### <number>"; keep only the number, without commas
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": label,
        }

    # Second value: options for the dataset mapping step; here, drop the original "question" column
    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
```

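With several entries in `reward_funcs`, each completion's scores are combined using `reward_weights`, and GRPO then compares each of the `num_generations` completions for a prompt against the others in its group. The sketch below shows that group-relative normalization as described in the GRPO paper and TRL's documentation; the epsilon and the use of the sample standard deviation are assumptions, and TRL's batched implementation has further options. It needs `num_generations >= 2`.

```python
import statistics
from typing import List, Sequence


def combine_rewards(per_func: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """per_func[k][i] is reward function k's score for completion i."""
    if len(per_func) != len(weights):
        raise ValueError("need one weight per reward function")
    return [sum(w * scores[i] for w, scores in zip(weights, per_func))
            for i in range(len(per_func[0]))]


def group_advantages(rewards: Sequence[float], num_generations: int, eps: float = 1e-4) -> List[float]:
    """Normalize each reward against the other completions for the same prompt."""
    if len(rewards) % num_generations != 0:
        raise ValueError("rewards must contain whole groups")
    advantages: List[float] = []
    for start in range(0, len(rewards), num_generations):
        group = list(rewards[start:start + num_generations])
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample standard deviation
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# One prompt, num_generations: 4, a single reward function with weight 1.0
rewards = combine_rewards([[1.0, 0.0, 0.0, 1.0]], weights=[1.0])
print(group_advantages(rewards, num_generations=4))
```

Completions that beat their group's average get a positive advantage and the rest a negative one, so a purely random reward like `rand_reward_func` gives the model no consistent signal, which is why it is only useful as a wiring test.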
For more examples of custom reward functions, see the [TRL GRPO docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).

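As one more example, a reward can use the `answer` column produced by `oai_gsm8k_transform` above. This sketch assumes TRL's conventions for conversational prompts: each completion arrives as a list of message dicts, and remaining dataset columns (here `answer`) are passed as keyword arguments with one value per completion. The function name and the last-number extraction rule are hypothetical.

```python
# rewards.py (continued)
import re

_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def correctness_reward_func(completions, answer, **kwargs) -> list[float]:
    """Return 1.0 when the last number in a completion equals the reference answer."""
    rewards = []
    for completion, reference in zip(completions, answer):
        text = completion[0]["content"]  # conversational format: a list of messages
        numbers = _NUMBER.findall(text.replace(",", ""))
        rewards.append(1.0 if numbers and numbers[-1] == reference else 0.0)
    return rewards
```

It would be referenced as `rewards.correctness_reward_func` in `reward_funcs`, with a matching entry added to `reward_weights`.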
For all available options, see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).

#### GRPO with DAPO/Dr. GRPO loss

The DAPO paper, and subsequently the Dr. GRPO paper, proposed alternative loss functions for GRPO to remedy the length bias introduced by normalizing each response's loss by its own length, which distorts the penalty applied to longer responses.

```yaml
trl:
  loss_type: dr_grpo
  # Normalizes loss based on max completion length (default: 256)
  max_completion_length:
```

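The loss types differ mainly in how per-token losses are averaged. The sketch below contrasts the original GRPO aggregation (average within each completion, then across completions) with the Dr. GRPO variant (sum over all tokens divided by the number of completions times `max_completion_length`), mirroring the formulas in TRL's loss-type documentation; it works on plain lists and omits padding masks.

```python
from typing import List, Sequence


def grpo_aggregate(per_token: Sequence[Sequence[float]]) -> float:
    """Mean over tokens within each completion, then mean over completions."""
    return sum(sum(seq) / len(seq) for seq in per_token) / len(per_token)


def dr_grpo_aggregate(per_token: Sequence[Sequence[float]], max_completion_length: int) -> float:
    """Sum over all tokens, divided by a constant independent of completion length."""
    total = sum(sum(seq) for seq in per_token)
    return total / (len(per_token) * max_completion_length)


# Two completions with the same per-token loss but different lengths
losses: List[List[float]] = [[1.0] * 10, [1.0] * 100]
print(grpo_aggregate(losses))  # 1.0: each completion counts equally
print(dr_grpo_aggregate(losses, max_completion_length=256))
```

Under the GRPO average, each token of a short completion carries more weight than a token of a long one; with a constant denominator every token contributes equally, regardless of the length of the response it belongs to.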
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

### SimPO

SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```

This method uses the same dataset format as [DPO](#dpo).

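SimPO's loss needs no reference model: it compares length-normalized (average per-token) log-probabilities of the chosen and rejected responses and subtracts a target margin. A scalar sketch of the objective from the SimPO paper, assuming `rl_beta` plays the role of beta and `simpo_gamma` that of the margin gamma; any additional `cpo_alpha`-weighted term that CPOTrainer may add is not modeled here.

```python
import math


def average_logp(token_logps: list[float]) -> float:
    """Length normalization: mean log-probability per response token."""
    return sum(token_logps) / len(token_logps)


def simpo_loss(avg_logp_chosen: float, avg_logp_rejected: float,
               beta: float = 0.1, gamma: float = 0.5) -> float:
    """-log(sigmoid(beta * (avg_logp_chosen - avg_logp_rejected) - gamma))."""
    logits = beta * (avg_logp_chosen - avg_logp_rejected) - gamma
    return math.log1p(math.exp(-logits))


chosen = average_logp([-0.2, -0.4])
rejected = average_logp([-1.0, -1.5, -2.0])
print(simpo_loss(chosen, rejected))
```

Averaging over tokens keeps the implicit reward from favoring longer responses simply because they accumulate more log-probability mass.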
### Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```

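Each line of such a file is one JSON object in the format required by the chosen `type`. For example, a few lines matching the `chatml.intel` format could be written with the standard library; the records are placeholders.

```python
import json

records = [
    {"system": "You are a helpful assistant.", "question": "What is the capital of France?",
     "chosen": "Paris.", "rejected": "Lyon."},
    # "system" is optional in the chatml.intel format
    {"question": "2 + 2 =", "chosen": "4", "rejected": "5"},
]

# JSON Lines: one object per line, UTF-8 encoded
with open("orca_rlhf.jsonl", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```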
### TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, since no separate reference model needs to be loaded: reference-model log-probabilities are obtained by temporarily disabling the PEFT adapters. This is enabled by default. To turn it off and load a separate reference model, use the following config:

```yaml
# Load a separate reference model when training adapters
rl_adapter_ref_model: true
```
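The idea behind auto-unwrapping can be illustrated with a toy layer in which the adapter adds an offset (standing in for a LoRA update) to a frozen base weight: with the adapter enabled the layer acts as the policy, and with it temporarily disabled it reproduces the frozen base model, i.e. the reference. This is a conceptual sketch only; the class is hypothetical, and in PEFT the corresponding switch is the `disable_adapter()` context manager on a `PeftModel`, which TRL uses for you.

```python
from contextlib import contextmanager
from typing import Iterator


class ToyAdapterLayer:
    """y = base_weight * x, plus an adapter offset while the adapter is enabled."""

    def __init__(self, base_weight: float, adapter_delta: float) -> None:
        self.base_weight = base_weight      # frozen, shared by policy and reference
        self.adapter_delta = adapter_delta  # the only trained parameter
        self.adapter_enabled = True

    def __call__(self, x: float) -> float:
        weight = self.base_weight + (self.adapter_delta if self.adapter_enabled else 0.0)
        return weight * x

    @contextmanager
    def disable_adapter(self) -> Iterator["ToyAdapterLayer"]:
        """Temporarily run only the frozen base weight, i.e. the reference model."""
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield self
        finally:
            self.adapter_enabled = previous


layer = ToyAdapterLayer(base_weight=2.0, adapter_delta=0.5)
policy_out = layer(3.0)  # 7.5
with layer.disable_adapter():
    reference_out = layer(3.0)  # 6.0
```

Because the base weights are shared, only one copy of the model sits in memory; the cost is an extra forward pass with adapters switched off.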