---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---
# Overview
Reinforcement Learning from Human Feedback (RLHF) is a method whereby a language model is optimized using human
feedback. Methods include, but are not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
# RLHF using Axolotl
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::
We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap and expose in axolotl. Each method has its own supported ways of loading datasets and its own prompt formats.
::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::
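As a rough illustration of the tip above, the sketch below maps an `rl` method and a `type` string to the file and function that would define it. The helper name `locate_strategy` is hypothetical, and the assumption that the part before the last dot names a Python file inside the method directory is inferred from the layout described above, not taken from axolotl's loader.

```python
from pathlib import PurePosixPath


def locate_strategy(method: str, type_name: str) -> tuple[str, str]:
    """Return (file path, function name) for a dataset `type` string.

    Hypothetical helper: assumes `type_name` looks like "<file>.<function>"
    and that the file lives under src/axolotl/prompt_strategies/<method>/.
    """
    file_stem, _, function_name = type_name.rpartition(".")
    if not file_stem or not function_name:
        raise ValueError(f"expected '<file>.<function>', got {type_name!r}")
    path = PurePosixPath("src/axolotl/prompt_strategies") / method / f"{file_stem}.py"
    return str(path), function_name


print(locate_strategy("dpo", "chatml.intel"))
```

Under these assumptions, `type: chatml.intel` with `rl: dpo` would point at the `intel` function in the `chatml.py` file of the `dpo` directory.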
## DPO
Example config:
```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```
DPO supports the following types with the following dataset format:
### chatml.argilla
```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```
### chatml.argilla_chat
```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
### chatml.icr
```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```
### chatml.intel
```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```
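To make the row layout concrete, here is a minimal sketch of how a `chatml.intel` row could be rendered into the prompt, chosen, and rejected strings used for DPO. The function name `render_intel_row` and the exact placement of the ChatML markers are assumptions for illustration; the strategy under `src/axolotl/prompt_strategies/dpo` remains the reference.

```python
def render_intel_row(row: dict) -> dict:
    """Render one chatml.intel-style row into ChatML strings (illustrative)."""
    parts = []
    # The system field is optional, so only emit a system turn when it is set.
    if row.get("system"):
        parts.append(f"<|im_start|>system\n{row['system']}<|im_end|>\n")
    parts.append(f"<|im_start|>user\n{row['question']}<|im_end|>\n")
    # Leave the assistant turn open so each completion continues from it.
    parts.append("<|im_start|>assistant\n")
    return {
        "prompt": "".join(parts),
        "chosen": f"{row['chosen']}<|im_end|>",
        "rejected": f"{row['rejected']}<|im_end|>",
    }


example = render_intel_row(
    {"question": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
)
print(example["prompt"] + example["chosen"])
```

The other `chatml.*` types would differ mainly in which input fields they read, while the rendered output keeps the same three keys.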
### chatml.prompt_pairs
```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```
### chatml.ultra
```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
### llama3.argilla
```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```
### llama3.argilla_chat
```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
### llama3.icr
```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```
### llama3.intel
```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```
### llama3.prompt_pairs
```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```
### llama3.ultra
```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
### zephyr.nectar
```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```
### chat_template.default
```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_field_role: "role"
    message_field_content: "content"
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```
Sample input format:
```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    },
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```
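The following sketch shows one way the `field_*` options could be read: the shared message list forms the prompt, and the chosen or rejected message is appended as the final assistant turn. The function `split_preference_row` is a hypothetical helper, not part of axolotl, and applying the tokenizer's chat template to the resulting conversations is left out.

```python
def split_preference_row(
    row: dict,
    field_messages: str = "messages",
    field_chosen: str = "chosen",
    field_rejected: str = "rejected",
) -> dict:
    """Build the chosen and rejected conversations from one row (illustrative)."""
    shared = list(row[field_messages])
    return {
        "prompt": shared,
        # Each conversation is the shared history plus one final assistant turn.
        "chosen": shared + [row[field_chosen]],
        "rejected": shared + [row[field_rejected]],
    }


row = {
    "messages": [{"role": "user", "content": "Name a prime number."}],
    "chosen": {"role": "assistant", "content": "7"},
    "rejected": {"role": "assistant", "content": "8"},
}
pair = split_preference_row(row)
print(pair["chosen"][-1]["content"], pair["rejected"][-1]["content"])
```

Keeping the field names as parameters mirrors the config above, where each `field_*` key can point at a differently named column.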
### user_defined.default
For custom behaviors, map your dataset's fields and formats explicitly in the config:
```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default
    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```
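A minimal sketch of how the `*_format` templates could be applied to a row, assuming they behave like Python `str.format` templates keyed by `system`, `prompt`, `chosen`, and `rejected`. The helper `apply_formats` is hypothetical, and whether axolotl fills the templates exactly this way is an assumption.

```python
def apply_formats(row: dict, cfg: dict) -> dict:
    """Fill the configured templates with values from one row (illustrative)."""
    values = {
        # The system field is optional, so fall back to an empty string.
        "system": row.get(cfg["field_system"], ""),
        "prompt": row[cfg["field_prompt"]],
        "chosen": row[cfg["field_chosen"]],
        "rejected": row[cfg["field_rejected"]],
    }
    # str.format ignores keyword arguments that a template does not use.
    return {
        "prompt": cfg["prompt_format"].format(**values),
        "chosen": cfg["chosen_format"].format(**values),
        "rejected": cfg["rejected_format"].format(**values),
    }


cfg = {
    "field_prompt": "prompt",
    "field_system": "system",
    "field_chosen": "chosen",
    "field_rejected": "rejected",
    "prompt_format": "[INST] {prompt} [/INST]",
    "chosen_format": "{chosen}",
    "rejected_format": "{rejected}",
}
row = {"prompt": "Say hi.", "chosen": "Hi!", "rejected": "No."}
print(apply_formats(row, cfg))
```

The `[INST]` wrapper in `prompt_format` is only an example value; any template using the same placeholders would work the same way in this sketch.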
## IPO
As IPO is just DPO with a different loss function, all supported options for DPO work here.
```yaml
rl: ipo
```
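To show what "a different loss function" means in practice, the sketch below computes both losses from the same log-probability margin. It follows the usual formulations (sigmoid loss for DPO, squared loss around `1 / (2 * beta)` for IPO); the function names are illustrative, and details such as how token log-probabilities are summed or averaged are left out.

```python
import math


def preference_margin(
    pi_chosen: float, pi_rejected: float, ref_chosen: float, ref_rejected: float
) -> float:
    """Margin h = (policy chosen - rejected) - (reference chosen - rejected)."""
    return (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)


def dpo_loss(margin: float, beta: float) -> float:
    """Sigmoid DPO loss: -log(sigmoid(beta * margin))."""
    return math.log1p(math.exp(-beta * margin))


def ipo_loss(margin: float, beta: float) -> float:
    """IPO loss: squared distance of the margin from 1 / (2 * beta)."""
    return (margin - 1.0 / (2.0 * beta)) ** 2


# Sequence log-probabilities under the policy and the reference model.
h = preference_margin(-10.0, -14.0, -11.0, -12.0)
print(h, dpo_loss(h, beta=0.1), ipo_loss(h, beta=0.1))
```

Both losses consume the same inputs, which is why the dataset types and options carry over unchanged. The DPO loss keeps decreasing as the margin grows, while the IPO loss is minimized at a finite target margin.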
## ORPO
Paper: https://arxiv.org/abs/2403.07691
```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```
ORPO supports the following types with the following dataset format:
### chat_template.argilla
```json
{
  "system": "...", // optional
  "prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
  // chosen/rejected should be same till last content and only even-number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
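The constraint noted in the comment above (identical turns up to the last reply, and an even number of alternating user/assistant turns) can be checked before training. The function below is a hypothetical pre-flight check, not something axolotl provides, and it reads the constraint literally.

```python
def is_valid_orpo_pair(chosen: list[dict], rejected: list[dict]) -> bool:
    """Check the chosen/rejected constraints described above (illustrative)."""
    # Both conversations need the same, non-zero, even number of turns.
    if len(chosen) != len(rejected) or len(chosen) == 0 or len(chosen) % 2 != 0:
        return False
    # Roles must alternate user, assistant, user, assistant, ...
    for conversation in (chosen, rejected):
        for index, message in enumerate(conversation):
            expected_role = "user" if index % 2 == 0 else "assistant"
            if message["role"] != expected_role:
                return False
    # Everything before the final assistant reply must be identical.
    return chosen[:-1] == rejected[:-1]


chosen = [
    {"role": "user", "content": "Capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
rejected = [
    {"role": "user", "content": "Capital of France?"},
    {"role": "assistant", "content": "Lyon."},
]
print(is_valid_orpo_pair(chosen, rejected))
```

Running a check like this over the dataset once can surface malformed rows early, before they reach tokenization.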
## KTO
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
remove_unused_columns: false
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```
KTO supports the following types with the following dataset format:
### chatml.argilla
```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```
### chatml.argilla_chat
```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```
### chatml.intel
```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```
### chatml.prompt_pairs
```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```
### chatml.ultra
```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```
### llama3.argilla
```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```
### llama3.argilla_chat
```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```
### llama3.intel
```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```
### llama3.prompt_pairs
```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```
### llama3.ultra
```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```
### user_defined.default
For custom behaviors, map your dataset's fields and formats explicitly in the config:
```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default
    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```
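If you start from paired preference data, one way to produce rows in this shape is to split each pair into two unpaired examples. The sketch below assumes a boolean `label` marking whether the completion is desirable, which is the convention TRL's KTO trainer works with; whether `user_defined.default` accepts or converts other label values is not assumed here. The helper `pair_to_kto_rows` is hypothetical.

```python
def pair_to_kto_rows(row: dict) -> list[dict]:
    """Split one chosen/rejected pair into two KTO-style rows (illustrative)."""
    base = {"prompt": row["prompt"]}
    # Carry the optional system field over only when it is present.
    if row.get("system"):
        base["system"] = row["system"]
    return [
        {**base, "completion": row["chosen"], "label": True},
        {**base, "completion": row["rejected"], "label": False},
    ]


pair = {"prompt": "Say hi.", "chosen": "Hi!", "rejected": "Go away."}
for kto_row in pair_to_kto_rows(pair):
    print(kto_row)
```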
## Using local dataset files
```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```
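A small sketch of producing such a file: each line of a `.jsonl` file is one JSON object, here using the `chatml.intel` field names from the DPO section. The helper `write_jsonl` and the sample rows are made up for illustration.

```python
import json
from pathlib import Path


def write_jsonl(path: Path, rows: list[dict]) -> None:
    """Write one JSON object per line, the layout expected of a .jsonl file."""
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=False) + "\n")


rows = [
    {
        "system": "You are a helpful assistant.",
        "question": "What is 2 + 2?",
        "chosen": "2 + 2 equals 4.",
        "rejected": "2 + 2 equals 5.",
    },
]
write_jsonl(Path("orca_rlhf.jsonl"), rows)
```

The file name matches the `data_files` entry in the config above, so the path is resolved relative to where training is launched.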
## TRL auto-unwrapping for PEFT
TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```
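The idea behind reusing one model as both policy and reference can be shown with a toy example. The class below is a stand-in, not TRL or PEFT code: it models an adapter as a delta added to a frozen base weight. The method name `disable_adapter` mirrors the context manager of the same name in PEFT, but everything else here is a simplifying assumption.

```python
from contextlib import contextmanager


class ToyAdapterModel:
    """Toy stand-in for a PEFT model: a frozen base weight plus a trainable delta."""

    def __init__(self, base_weight: float, adapter_delta: float):
        self.base_weight = base_weight
        self.adapter_delta = adapter_delta
        self.adapter_enabled = True

    def score(self, x: float) -> float:
        weight = self.base_weight
        if self.adapter_enabled:
            weight += self.adapter_delta
        return weight * x

    @contextmanager
    def disable_adapter(self):
        """Temporarily turn the adapter off, restoring the old state on exit."""
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield self
        finally:
            self.adapter_enabled = previous


model = ToyAdapterModel(base_weight=2.0, adapter_delta=0.5)
policy_score = model.score(3.0)  # adapter on: behaves as the policy
with model.disable_adapter():
    reference_score = model.score(3.0)  # adapter off: behaves as the reference
print(policy_score, reference_score)
```

Only one set of base weights is held in memory, which is where the saving comes from; setting `rl_adapter_ref_model: true` trades that saving for a separately loaded reference model.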