---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---

# Overview

Reinforcement Learning from Human Feedback (RLHF) is a method whereby a language model is optimized from data using human feedback. Various methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)

# RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The `type:` can be retrieved from `{method}.{function_name}`.
:::
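
For example, assuming the DPO strategies include a `chatml` module defining an `intel` function (per the layout described above), the matching dataset entry would be:

```yaml
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel  # <module>.<function>
```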

## DPO

Example config:

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```

DPO supports the following types with their corresponding dataset formats:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_field_role: "role"
    message_field_content: "content"
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    },
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```
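
The `roles` mapping normalizes role names from the source data. For example, assuming a dataset that labels its turns `human` and `gpt` (hypothetical values), each canonical role lists the dataset values it should accept; this fragment would replace the `roles:` block in the config above:

```yaml
roles:
  user: ["human", "user"]
  assistant: ["gpt", "assistant"]
  system: ["system"]
```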

### user_defined.default

For custom behaviors:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```

The input format is a simple JSON object whose fields are customizable via the config above:

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```
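
For instance, a dataset whose columns use different names (the `query`/`good`/`bad` names below are hypothetical) can be consumed by remapping the fields; the `*_format` strings are templates that wrap each field's value:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "query"
    field_chosen: "good"
    field_rejected: "bad"
    prompt_format: "[INST] {prompt} [/INST]"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```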

## IPO

As IPO is just DPO with a different loss function, all supported options for DPO work here as well.

```yaml
rl: ipo
```
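
A complete IPO config is therefore just a DPO config with the `rl` value swapped; for example, reusing the dataset from the DPO example above:

```yaml
rl: ipo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
```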

## ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```

ORPO supports the following types with their corresponding dataset formats:

### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if available, taken as the user message for single-turn instead of from the lists below

  // chosen/rejected should be identical except for the final assistant content,
  // with an even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

## KTO

```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```

KTO supports the following types with their corresponding dataset formats:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### user_defined.default

For custom behaviors:

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"
```

The input format is a simple JSON object whose fields are customizable via the config above:

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```
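
As a concrete illustration (assuming TRL's KTO dataset convention, where the label is a boolean marking the completion as desirable or undesirable), a row might look like:

```json
{
  "prompt": "What is the capital of France?",
  "completion": "The capital of France is Paris.",
  "label": true
}
```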

## Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```
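
Each line of `orca_rlhf.jsonl` would then be a single JSON object in the `chatml.intel` format shown earlier, e.g.:

```json
{"system": "...", "question": "...", "chosen": "...", "rejected": "..."}
```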

## TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, since an additional reference model does not need to be loaded, and reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

```yaml
# load a separate reference model when training with an adapter
rl_adapter_ref_model: true
```