Feat(doc): Reorganize documentation, fix broken syntax, update notes (#2348)

* feat(doc): organize docs, add to menu bar, fix broken formatting

* feat: add link to custom integrations

* feat: update readme for integrations to include citations and repo link

* chore: update lm_eval info

* chore: use fullname

* Update docs/cli.qmd per suggestion

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* feat: add sweep doc

* feat: add kd doc

* fix: remove toc

* fix: update deprecation

* feat: add more info about chat_template issues

* fix: heading level

* fix: shell->bash code block

* fix: ray link

* fix(doc): heading level, header links, formatting

* feat: add grpo docs

* feat: add style changes

* fix: wrong cli arg for lm-eval

* fix: remove old run method

* feat: load custom integration doc dynamically

* fix: remove old cli way

* fix: toc

* fix: minor formatting

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
This commit is contained in:
NanoCode012
2025-02-25 16:09:37 +07:00
committed by GitHub
parent 1110a37e21
commit 2efe1b4c09
32 changed files with 940 additions and 443 deletions

View File

@@ -3,22 +3,22 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
toc-depth: 4
---
# Overview
## Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
# RLHF using Axolotl
## RLHF using Axolotl
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
@@ -30,7 +30,7 @@ We rely on the [TRL](https://github.com/huggingface/trl) library for implementat
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::
## DPO
### DPO
Example config:
@@ -47,7 +47,7 @@ datasets:
DPO supports the following types with the following dataset format:
### chatml.argilla
#### chatml.argilla
```json
{
@@ -58,7 +58,7 @@ DPO supports the following types with the following dataset format:
}
```
### chatml.argilla_chat
#### chatml.argilla_chat
```json
{
@@ -73,7 +73,7 @@ DPO supports the following types with the following dataset format:
}
```
### chatml.icr
#### chatml.icr
```json
{
@@ -84,7 +84,7 @@ DPO supports the following types with the following dataset format:
}
```
### chatml.intel
#### chatml.intel
```json
{
@@ -95,7 +95,7 @@ DPO supports the following types with the following dataset format:
}
```
### chatml.prompt_pairs
#### chatml.prompt_pairs
```json
{
@@ -106,7 +106,7 @@ DPO supports the following types with the following dataset format:
}
```
### chatml.ultra
#### chatml.ultra
```json
{
@@ -123,7 +123,7 @@ DPO supports the following types with the following dataset format:
}
```
### llama3.argilla
#### llama3.argilla
```json
{
@@ -134,7 +134,7 @@ DPO supports the following types with the following dataset format:
}
```
### llama3.argilla_chat
#### llama3.argilla_chat
```json
{
@@ -149,7 +149,7 @@ DPO supports the following types with the following dataset format:
}
```
### llama3.icr
#### llama3.icr
```json
{
@@ -160,7 +160,7 @@ DPO supports the following types with the following dataset format:
}
```
### llama3.intel
#### llama3.intel
```json
{
@@ -171,7 +171,7 @@ DPO supports the following types with the following dataset format:
}
```
### llama3.prompt_pairs
#### llama3.prompt_pairs
```json
{
@@ -182,7 +182,7 @@ DPO supports the following types with the following dataset format:
}
```
### llama3.ultra
#### llama3.ultra
```json
{
@@ -199,7 +199,7 @@ DPO supports the following types with the following dataset format:
}
```
### zephyr.nectar
#### zephyr.nectar
```json
{
@@ -218,7 +218,7 @@ DPO supports the following types with the following dataset format:
}
```
### chat_template.default
#### chat_template.default
```yaml
rl: dpo
@@ -264,7 +264,7 @@ Sample input format:
}
```
### user_defined.default
#### user_defined.default
For custom behaviors,
@@ -295,7 +295,7 @@ The input format is a simple JSON input with customizable fields based on the ab
}
```
## IPO
### IPO
As IPO is just DPO with a different loss function, all supported options for DPO works here.
@@ -303,7 +303,7 @@ As IPO is just DPO with a different loss function, all supported options for DPO
rl: ipo
```
## ORPO
### ORPO
Paper: https://arxiv.org/abs/2403.07691
@@ -320,7 +320,7 @@ datasets:
ORPO supports the following types with the following dataset format:
### chat_template.argilla
#### chat_template.argilla
```json
{
@@ -339,7 +339,7 @@ ORPO supports the following types with the following dataset format:
}
```
## KTO
### KTO
```yaml
rl: kto
@@ -360,7 +360,7 @@ gradient_checkpointing_kwargs:
KTO supports the following types with the following dataset format:
### chatml.argilla
#### chatml.argilla
```json
{
@@ -370,7 +370,7 @@ KTO supports the following types with the following dataset format:
}
```
### chatml.argilla_chat
#### chatml.argilla_chat
```json
{
@@ -383,7 +383,7 @@ KTO supports the following types with the following dataset format:
}
```
### chatml.intel
#### chatml.intel
```json
{
@@ -393,7 +393,7 @@ KTO supports the following types with the following dataset format:
}
```
### chatml.prompt_pairs
#### chatml.prompt_pairs
```json
{
@@ -403,7 +403,7 @@ KTO supports the following types with the following dataset format:
}
```
### chatml.ultra
#### chatml.ultra
```json
{
@@ -413,7 +413,7 @@ KTO supports the following types with the following dataset format:
}
```
### llama3.argilla
#### llama3.argilla
```json
{
@@ -423,7 +423,7 @@ KTO supports the following types with the following dataset format:
}
```
### llama3.argilla_chat
#### llama3.argilla_chat
```json
{
@@ -434,7 +434,7 @@ KTO supports the following types with the following dataset format:
}
```
### llama3.intel
#### llama3.intel
```json
{
@@ -444,7 +444,7 @@ KTO supports the following types with the following dataset format:
}
```
### llama3.prompt_pairs
#### llama3.prompt_pairs
```json
{
@@ -454,7 +454,7 @@ KTO supports the following types with the following dataset format:
}
```
### llama3.ultra
#### llama3.ultra
```json
{
@@ -464,7 +464,7 @@ KTO supports the following types with the following dataset format:
}
```
### user_defined.default
#### user_defined.default
For custom behaviors,
@@ -494,7 +494,49 @@ The input format is a simple JSON input with customizable fields based on the ab
}
```
## Using local dataset files
### GRPO
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
```
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
### Using local dataset files
```yaml
datasets:
@@ -505,7 +547,7 @@ datasets:
type: chatml.intel
```
## TRL auto-unwrapping for PEFT
### TRL auto-unwrapping for PEFT
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional refreference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config: