Feat(doc): Reorganize documentation, fix broken syntax, update notes (#2348)
* feat(doc): organize docs, add to menu bar, fix broken formatting * feat: add link to custom integrations * feat: update readme for integrations to include citations and repo link * chore: update lm_eval info * chore: use fullname * Update docs/cli.qmd per suggestion Co-authored-by: Dan Saunders <danjsaund@gmail.com> * feat: add sweep doc * feat: add kd doc * fix: remove toc * fix: update deprecation * feat: add more info about chat_template issues * fix: heading level * fix: shell->bash code block * fix: ray link * fix(doc): heading level, header links, formatting * feat: add grpo docs * feat: add style changes * fix: wrong cli arg for lm-eval * fix: remove old run method * feat: load custom integration doc dynamically * fix: remove old cli way * fix: toc * fix: minor formatting --------- Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
116
docs/rlhf.qmd
116
docs/rlhf.qmd
@@ -3,22 +3,22 @@ title: "RLHF (Beta)"
|
||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
# Overview
|
||||
## Overview
|
||||
|
||||
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
|
||||
feedback. Various methods include, but not limited to:
|
||||
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
- [Direct Preference Optimization (DPO)](#dpo)
|
||||
- [Identity Preference Optimization (IPO)](#ipo)
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
|
||||
|
||||
# RLHF using Axolotl
|
||||
## RLHF using Axolotl
|
||||
|
||||
::: {.callout-important}
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
@@ -30,7 +30,7 @@ We rely on the [TRL](https://github.com/huggingface/trl) library for implementat
|
||||
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
|
||||
:::
|
||||
|
||||
## DPO
|
||||
### DPO
|
||||
|
||||
Example config:
|
||||
|
||||
@@ -47,7 +47,7 @@ datasets:
|
||||
|
||||
DPO supports the following types with the following dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
#### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -58,7 +58,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
#### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -73,7 +73,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.icr
|
||||
#### chatml.icr
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -84,7 +84,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
#### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -95,7 +95,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
#### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -106,7 +106,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
#### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -123,7 +123,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
#### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -134,7 +134,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
#### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -149,7 +149,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.icr
|
||||
#### llama3.icr
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -160,7 +160,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
#### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -171,7 +171,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
#### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -182,7 +182,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
#### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -199,7 +199,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### zephyr.nectar
|
||||
#### zephyr.nectar
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -218,7 +218,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chat_template.default
|
||||
#### chat_template.default
|
||||
|
||||
```yaml
|
||||
rl: dpo
|
||||
@@ -264,7 +264,7 @@ Sample input format:
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
#### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
@@ -295,7 +295,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
}
|
||||
```
|
||||
|
||||
## IPO
|
||||
### IPO
|
||||
|
||||
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
||||
|
||||
@@ -303,7 +303,7 @@ As IPO is just DPO with a different loss function, all supported options for DPO
|
||||
rl: ipo
|
||||
```
|
||||
|
||||
## ORPO
|
||||
### ORPO
|
||||
|
||||
Paper: https://arxiv.org/abs/2403.07691
|
||||
|
||||
@@ -320,7 +320,7 @@ datasets:
|
||||
|
||||
ORPO supports the following types with the following dataset format:
|
||||
|
||||
### chat_template.argilla
|
||||
#### chat_template.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -339,7 +339,7 @@ ORPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
## KTO
|
||||
### KTO
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
@@ -360,7 +360,7 @@ gradient_checkpointing_kwargs:
|
||||
|
||||
KTO supports the following types with the following dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
#### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -370,7 +370,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
#### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -383,7 +383,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
#### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -393,7 +393,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
#### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -403,7 +403,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
#### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -413,7 +413,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
#### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -423,7 +423,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
#### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -434,7 +434,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
#### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -444,7 +444,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
#### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -454,7 +454,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
#### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -464,7 +464,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
#### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
@@ -494,7 +494,49 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
}
|
||||
```
|
||||
|
||||
## Using local dataset files
|
||||
### GRPO
|
||||
|
||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||
|
||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
import random
|
||||
|
||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||
return [random.uniform(0, 1) for _ in completions]
|
||||
|
||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]},],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
```
|
||||
|
||||
```yaml
|
||||
rl: grpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 256
|
||||
use_vllm: True
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.15
|
||||
num_generations: 4
|
||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
|
||||
```
|
||||
|
||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||
|
||||
### Using local dataset files
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -505,7 +547,7 @@ datasets:
|
||||
type: chatml.intel
|
||||
```
|
||||
|
||||
## TRL auto-unwrapping for PEFT
|
||||
### TRL auto-unwrapping for PEFT
|
||||
|
||||
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional refreference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user