feat(doc): Improve guide to dataset types with better examples (#2286)

This commit is contained in:
NanoCode012
2025-02-14 04:01:41 +07:00
committed by GitHub
parent ffae8d6a95
commit 40362d60e0
3 changed files with 889 additions and 25 deletions

View File

@@ -1,26 +1,39 @@
---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---
### Overview
# Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
### RLHF using Axolotl
# RLHF using Axolotl
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap around to expose in axolotl. Each method has their own supported ways of loading datasets and prompt formats.
::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::
## DPO
Example config:
#### DPO
```yaml
rl: dpo
datasets:
@@ -32,12 +45,264 @@ datasets:
type: chatml
```
#### IPO
DPO supports the following types with the following dataset format:
### chatml.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
```
### chatml.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### chatml.icr
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### chatml.intel
```json
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
```
### chatml.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
### chatml.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### llama3.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
```
### llama3.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### llama3.icr
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### llama3.intel
```json
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
```
### llama3.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
### llama3.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### zephyr.nectar
```json
{
"prompt": "...",
"answers": [
{
"answer": "...",
"rank": 1
},
{
"answer": "...",
"rank": 2
}
// ... more answers with ranks
]
}
```
### chat_template.default
```yaml
rl: dpo
datasets:
- path: ...
split: train
type: chat_template.default
field_messages: "messages"
field_chosen: "chosen"
field_rejected: "rejected"
message_field_role: "role"
message_field_content: "content"
roles:
user: ["user"]
assistant: ["assistant"]
system: ["system"]
```
Sample input format:
```json
{
"messages": [
{
"role": "system",
"content": "..."
},
{
"role": "user",
"content": "..."
},
// ... more messages
],
"chosen": {
"role": "assistant",
"content": "..."
},
"rejected": {
"role": "assistant",
"content": "..."
}
}
```
### user_defined.default
For custom behaviors,
```yaml
rl: dpo
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
## IPO
As IPO is just DPO with a different loss function, all supported options for DPO works here.
```yaml
rl: ipo
```
#### ORPO
## ORPO
Paper: https://arxiv.org/abs/2403.07691
@@ -52,8 +317,28 @@ datasets:
type: chat_template.argilla
```
ORPO supports the following types with the following dataset format:
#### KTO
### chat_template.argilla
```json
{
"system": "...", // optional
"prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
// chosen/rejected should be same till last content and only even-number of alternating user/assistant turns
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
## KTO
```yaml
rl: kto
@@ -72,7 +357,144 @@ gradient_checkpointing_kwargs:
use_reentrant: true
```
#### Using local dataset files
KTO supports the following types with the following dataset format:
### chatml.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
```
### chatml.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."}
],
"completion": [
{"role": "assistant", "content": "..."}
]
}
```
### chatml.intel
```json
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
```
### chatml.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### chatml.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### llama3.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
```
### llama3.argilla_chat
```json
{
"completion": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### llama3.intel
```json
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
```
### llama3.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### llama3.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### user_defined.default
For custom behaviors,
```yaml
rl: kto
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "...",
"label": "..."
}
```
## Using local dataset files
```yaml
datasets:
- ds_type: json
@@ -82,9 +504,9 @@ datasets:
type: chatml.intel
```
#### Trl autounwrap for peft
## TRL auto-unwrapping for PEFT
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional refreference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config:
```yaml
# load ref model when adapter training.