---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---

# Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback. Various methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)

# RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of the various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by looking in `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The `type:` can be retrieved as `{method}.{function_name}`.
:::

## DPO

Example config:

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```

DPO supports the following types, with the corresponding dataset formats:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```
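For example, a small local dataset in the `chatml.intel` format can be produced as JSONL with the standard library. The filename and sample values below are only illustrative:

```python
import json

# One preference pair per line; the "system" field is optional and may be
# omitted from a row entirely.
rows = [
    {
        "system": "You are a helpful assistant.",
        "question": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France does not have a capital city.",
    },
]

with open("dpo_sample.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
```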

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    }
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```
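The `message_property_mappings` entry exists so that datasets storing messages under different keys (for example `from`/`value`) can be remapped onto `role`/`content`. Conceptually it works like the sketch below; this is not axolotl's actual implementation, and the mapping direction (target property read from source key) is inferred from the `role: role` example above:

```python
def remap_message(message: dict, mappings: dict) -> dict:
    """Read each target property (e.g. role/content) from its mapped source key."""
    return {target: message[source] for target, source in mappings.items()}

# A hypothetical dataset that stores messages as {"from": ..., "value": ...}
raw = {"from": "user", "value": "Hello!"}
mappings = {"role": "from", "content": "value"}

print(remap_message(raw, mappings))  # {'role': 'user', 'content': 'Hello!'}
```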

### user_defined.default

For custom behaviors:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```
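The `*_format` options are plain Python-style format strings: each named field from the row is substituted into its template. A simplified sketch of the idea (not the actual loader code), using a hypothetical instruction-style template:

```python
row = {"prompt": "Name one prime number.", "chosen": "2", "rejected": "9"}

# Hypothetical custom templates; the defaults shown in the config above
# are simply "{prompt}", "{chosen}" and "{rejected}".
prompt_format = "[INST] {prompt} [/INST]"
chosen_format = "{chosen}"
rejected_format = "{rejected}"

prompt = prompt_format.format(prompt=row["prompt"])
chosen = chosen_format.format(chosen=row["chosen"])
rejected = rejected_format.format(rejected=row["rejected"])

print(prompt)  # [INST] Name one prime number. [/INST]
```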

## IPO

As IPO is just DPO with a different loss function, all supported options for DPO work here.

```yaml
rl: ipo
```

## ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```

ORPO supports the following types, with the corresponding dataset formats:

### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if available, will be taken as the user message for single-turn, instead of from the lists below

  // chosen/rejected should be identical up to the last content, with only an
  // even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

## KTO

```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```

KTO supports the following types, with the corresponding dataset formats:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### user_defined.default

For custom behaviors:

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```
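As a sketch, a KTO-style JSONL file could be written like this. The label values are an assumption here (TRL's KTO generally uses a boolean desirable/undesirable flag); check the loader for the exact values it accepts:

```python
import json

# Illustrative rows: same prompt, one desirable and one undesirable
# completion. The boolean labels are an assumption, not a documented contract.
rows = [
    {"system": "Be brief.", "prompt": "What is 2 + 2?", "completion": "4", "label": True},
    {"system": "Be brief.", "prompt": "What is 2 + 2?", "completion": "5", "label": False},
]

with open("kto_sample.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
```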

## Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```
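Before training, it can be worth sanity-checking that every line of a local file parses as JSON and carries the fields the chosen `type:` expects. A small sketch for `chatml.intel` (the file is created here only so the example is self-contained):

```python
import json

# Create a stand-in file; normally orca_rlhf.jsonl would already exist on disk.
with open("orca_rlhf.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps({"question": "Hi?", "chosen": "Hello!", "rejected": "Bye."}) + "\n")

# "system" is optional for chatml.intel, so only these three are required.
required = {"question", "chosen", "rejected"}
with open("orca_rlhf.jsonl", encoding="utf-8") as f:
    bad_lines = [i for i, line in enumerate(f, 1)
                 if not required <= set(json.loads(line))]

print("bad lines:", bad_lines)  # bad lines: []
```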

## TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, since an additional reference model does not need to be loaded: reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```