---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
---
### Overview

Reinforcement Learning from Human Feedback (RLHF) is a method whereby a language model is optimized from data using human feedback. Methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
### RLHF using Axolotl

> [!IMPORTANT]
> This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

The various RL training methods are implemented in trl and wrapped via axolotl. Below are examples of how you can use various preference datasets to train models that use ChatML.
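To make the preference objective concrete, here is a minimal scalar sketch of the DPO loss that trl minimizes (illustrative only, not axolotl's implementation; the log-probabilities and `beta` value are invented for the example):

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Scalar sketch of the DPO loss: -log sigmoid of the reward margin."""
    # Implicit rewards are beta-scaled log-ratios of policy vs. reference.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    return -math.log(1.0 / (1.0 + math.exp(-margin)))  # -log sigmoid(margin)

# When policy and reference agree, the loss sits at log(2); it falls as the
# policy prefers the chosen completion more strongly than the reference does.
baseline = dpo_loss(-1.0, -1.0, -1.0, -1.0)
improved = dpo_loss(-1.0, -3.0, -1.5, -2.5)
```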
#### DPO

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```
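Each dataset entry names a `type` that maps the dataset's columns into preference pairs. As an illustration, a row of `Intel/orca_dpo_pairs` is shaped roughly like this (field names per the dataset card; the values here are invented):

```python
# Invented example row shaped like Intel/orca_dpo_pairs; the chatml.intel
# type renders these fields into chosen/rejected ChatML conversations.
row = {
    "system": "You are a helpful assistant.",
    "question": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "France is in Europe.",
}
```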
#### IPO

```yaml
rl: ipo
```
#### ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```
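ORPO adds an odds-ratio penalty, weighted by `orpo_alpha`, on top of the standard SFT loss on the chosen response. A minimal scalar sketch of that penalty term, assuming sequence-level average log-probabilities as inputs (illustrative, not axolotl's implementation):

```python
import math

def orpo_penalty(chosen_logp, rejected_logp, alpha=0.1):
    """Sketch of ORPO's odds-ratio term: -alpha * log sigmoid(log odds ratio)."""
    def log_odds(logp):
        p = math.exp(logp)                     # sequence-level probability
        return logp - math.log(1.0 - p)        # log(p / (1 - p))
    log_or = log_odds(chosen_logp) - log_odds(rejected_logp)
    return -alpha * math.log(1.0 / (1.0 + math.exp(-log_or)))

# Equal probabilities give a penalty of alpha * log(2); the penalty shrinks
# as the chosen response becomes more likely than the rejected one.
tied = orpo_penalty(math.log(0.5), math.log(0.5), alpha=1.0)
separated = orpo_penalty(math.log(0.6), math.log(0.4), alpha=1.0)
```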
#### KTO

```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```
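Unlike DPO-style methods, KTO does not require paired preferences: each example is a single prompt/completion marked as desirable or undesirable. As an illustration (invented values; the exact column names depend on the dataset `type` used):

```python
# Invented KTO-style examples: unpaired feedback with a binary label.
examples = [
    {"prompt": "Summarize: the sky is blue.",
     "completion": "The sky is blue.", "label": True},
    {"prompt": "Summarize: the sky is blue.",
     "completion": "Bananas are yellow.", "label": False},
]
# kto_desirable_weight rebalances the loss when labels are skewed.
desirable = [e for e in examples if e["label"]]
```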
#### Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```
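For instance, a local `orca_rlhf.jsonl` could be produced like this (invented rows shaped like `Intel/orca_dpo_pairs`, to match the `chatml.intel` type above):

```python
import json

rows = [
    {"system": "You are a helpful assistant.",
     "question": "Name a primary color.",
     "chosen": "Red is a primary color.",
     "rejected": "I am not sure."},
]
# JSON Lines: one JSON object per line, matching ds_type: json above.
with open("orca_rlhf.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
```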
#### TRL auto-unwrap for PEFT

TRL supports auto-unwrapping PEFT models, so that a reference model does not need to be loaded separately, reducing the VRAM required. This is on by default. To turn it off, pass the following config:

```yaml
# load a separate ref model when training with an adapter.
rl_adapter_ref_model: true
```
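The idea behind the auto-unwrap can be illustrated with a toy model: temporarily disabling the adapter recovers the frozen base weights, which then serve as the reference policy. The class and numbers below are invented for illustration; trl/peft handle this internally via the adapter's disable mechanism.

```python
from contextlib import contextmanager

class ToyPeftModel:
    """Toy stand-in for a PEFT-wrapped model: a base weight plus an
    adapter delta that can be temporarily disabled (hypothetical names)."""
    def __init__(self, base=1.0, adapter_delta=0.5):
        self.base = base
        self.adapter_delta = adapter_delta
        self._adapter_on = True

    def forward(self, x):
        w = self.base + (self.adapter_delta if self._adapter_on else 0.0)
        return w * x

    @contextmanager
    def disable_adapter(self):
        self._adapter_on = False
        try:
            yield
        finally:
            self._adapter_on = True

model = ToyPeftModel()
policy_out = model.forward(2.0)      # adapter active: (1.0 + 0.5) * 2.0
with model.disable_adapter():
    ref_out = model.forward(2.0)     # base weights only: acts as the ref model
```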