---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but are not limited to:

- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo) — see also the [GRPO deep dive](grpo.qmd) for async features, custom rewards, and scaling
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
- [Energy-Based Fine-Tuning (EBFT)](#ebft) — see also the [EBFT guide](ebft.qmd) for detailed mode comparisons and configuration
- [NeMo Gym Integration](#nemo-gym-integration)

For help choosing between these methods, see [Choosing a Fine-Tuning Method](choosing_method.qmd).


## RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type:` can be retrieved from `{method}.{function_name}`.
:::

### DPO

Example config:

```yaml
rl: dpo
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml
```

DPO supports the following types with the following dataset format:

#### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

#### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

#### chat_template.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    },
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```

#### user_defined.default

For custom behaviors,

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_chosen: "chosen"
      field_rejected: "rejected"
      prompt_format: "{prompt}"
      chosen_format: "{chosen}"
      rejected_format: "{rejected}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### IPO

As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

```yaml
rl: ipo
```

### ORPO

Paper: https://arxiv.org/abs/2403.07691

```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false

chat_template: chatml
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    type: chat_template.argilla
```

ORPO supports the following types with the following dataset format:

#### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if available, used as the user message for single-turn prompts instead of the lists below

  // chosen/rejected should be identical up to the final assistant message, with an even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### KTO

```yaml
rl: kto
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default

remove_unused_columns: false

datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto
    type: llama3.ultra
    split: train

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
```

KTO supports the following types with the following dataset format:

#### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### user_defined.default

For custom behaviors,

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```

### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code). For a comprehensive guide covering async training, custom rewards, importance sampling, and scaling, see the [GRPO deep dive](grpo.qmd).
:::

In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we're using 4 GPUs: 2 for training and 2 for vLLM.

::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
```

Your `vLLM` instance will now spin up. In another terminal, kick off training on the remaining two GPUs:

```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```

::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::

#### Reward functions

GRPO uses custom reward functions and dataset transformations. Have them available as local Python files before training.

For example, to load OpenAI's GSM8K and use a random reward for completions:

```python
# rewards.py
import random

def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
```

To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).

To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).

#### OpenEnv Rollout Functions

GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.

For example, to implement a simple math-solving environment with step-by-step verification:

```python
# math_env.py
import re

def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Args:
        model: The language model
        processing_class: The tokenizer/processing_class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration

    Returns:
        List of completion strings
    """
    completions = []

    for prompt in prompts:
        # Apply chat template to prompt
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate step-by-step solution
        full_response = ""
        for step in range(5):  # Max 5 reasoning steps
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            )

            # Check if solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break
            full_response += step_text + "\n"

        completions.append(full_response)

    return completions

def math_reward(prompts, completions, answers, **kwargs):
    """Reward function that checks mathematical correctness"""
    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract predicted answer
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""

        # Compare with correct answer
        reward = 1.0 if predicted == str(correct_answer) else 0.0
        rewards.append(reward)

    return rewards

def math_transform(cfg, *args, **kwargs):
    """Transform dataset to GRPO format with answer field"""
    def transform_fn(example, processing_class=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": str(example["answer"]),
        }
    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 512
  num_generations: 4
  rollout_func: "math_env.math_solver_rollout" # Custom rollout function
  reward_funcs: ["math_env.math_reward"]
  reward_weights: [1.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: math_env.math_transform
```

The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:

- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration

The function should return a list of completion strings.

For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).

#### GRPO with DAPO/Dr. GRPO loss

The DAPO paper and, subsequently, the Dr. GRPO paper proposed alternative loss functions for GRPO to remediate the implicit penalty on longer responses.

```yaml
trl:
  loss_type: dr_grpo
  # Normalizes loss based on max completion length (default: 256)
  max_completion_length:
```
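
To make the difference concrete, here is a toy sketch of the two normalization schemes (illustrative only; the exact bookkeeping lives in TRL). Standard GRPO averages the token loss over each sequence's own length, while `dr_grpo` divides by a constant `max_completion_length`, so long and short completions contribute on the same scale:

```python
import torch

token_loss = torch.rand(4, 256)              # [num_completions, max_completion_length]
lengths = torch.tensor([40, 120, 200, 250])  # actual completion lengths
mask = (torch.arange(256).unsqueeze(0) < lengths.unsqueeze(1)).float()
max_completion_length = 256

# GRPO: mean over each sequence's own length, then mean over sequences
grpo_loss = ((token_loss * mask).sum(-1) / mask.sum(-1)).mean()

# Dr. GRPO: normalize by a constant, independent of each sequence's length
dr_grpo_loss = (token_loss * mask).sum() / (mask.shape[0] * max_completion_length)
```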

For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).

#### Async GRPO

Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.

```yaml
trl:
  use_data_producer: true # Enable data producer protocol
  use_vllm: true
  async_prefetch: true # Generate rollouts in background thread
  prefetch_depth: 1 # Number of rollouts to prefetch
  vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```
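
The prefetch pattern is essentially a bounded producer/consumer queue. The sketch below is a conceptual illustration (hypothetical helpers, not Axolotl's implementation) of how generation and optimization overlap:

```python
import queue
import threading

def prefetch_train_loop(generate_rollouts, prompt_batches, train_step, prefetch_depth=1):
    """Conceptual sketch: a producer thread keeps up to `prefetch_depth`
    rollout batches queued while the main thread trains on the previous one."""
    q = queue.Queue(maxsize=prefetch_depth)

    def producer():
        for prompts in prompt_batches:
            q.put(generate_rollouts(prompts))  # blocks when the queue is full
        q.put(None)                            # signal end of data

    threading.Thread(target=producer, daemon=True).start()
    while (rollouts := q.get()) is not None:
        train_step(rollouts)                   # the next batch is usually ready already
```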

::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::

##### vLLM LoRA Sync

By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true # Enable native LoRA sync
```

When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```

Then start training on a separate GPU:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::

##### Streaming Partial Batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

```yaml
trl:
  streaming_partial_batch: true
```

##### Importance Sampling Correction

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.

```yaml
trl:
  vllm_importance_sampling_correction: true # Enable IS correction
  importance_sampling_level: token # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```

- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
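
As a rough sketch of what these options control (illustrative only, not Axolotl's implementation), the IS ratio compares the current policy's log-probabilities with those of the stale weights that generated the rollout, and the threshold drops samples that drifted too far off-policy:

```python
import torch

def is_ratio(policy_logps, behavior_logps, mask, level="token", threshold=0.5):
    """policy_logps / behavior_logps: per-token log-probs under the current model
    and under the (stale) vLLM weights that generated the completion."""
    log_ratio = policy_logps - behavior_logps                 # [batch, seq_len]
    if level == "token":
        ratio = log_ratio.exp()                               # one ratio per token
    else:  # "sequence": one ratio per completion, broadcast over its tokens
        seq_mean = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1)
        ratio = seq_mean.exp().unsqueeze(-1).expand_as(log_ratio)
    keep = mask * (ratio >= threshold)                        # off-policy masking
    return ratio.detach(), keep
```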

##### Replay Buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

```yaml
trl:
  replay_buffer_size: 100 # Max cached groups (0 = disabled)
  replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```

::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
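
Conceptually, the buffer behaves like the following sketch (hypothetical helper, not the actual TRL/Axolotl code): a group whose rewards all tie carries no advantage signal, so it is swapped for a previously stored group whose rewards varied:

```python
import random
import statistics

class ReplayBuffer:
    def __init__(self, capacity=100):
        self.capacity = capacity
        self.groups = []

    def maybe_store(self, group):
        # only keep groups with reward variance, i.e. a usable learning signal
        if statistics.pstdev(group["rewards"]) > 0:
            self.groups.append(group)
            self.groups = self.groups[-self.capacity:]

    def maybe_replace(self, group):
        # a zero-variance group yields all-zero advantages; replay a cached one instead
        if statistics.pstdev(group["rewards"]) == 0 and self.groups:
            return random.choice(self.groups)
        return group
```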

##### Deferred Re-rolling

Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.

```yaml
trl:
  reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
  reroll_max_groups: 1 # Max groups to replace per batch
```

##### Zero-Advantage Batch Skipping

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.

```yaml
trl:
  skip_zero_advantage_batches: true # default
```

##### Parallel Reward Workers

Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.

```yaml
trl:
  reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```

##### Full Async GRPO Example

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.35
  dtype: auto

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true

datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train

gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

##### Multi-GPU Async GRPO

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.

**FSDP:**

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

**DeepSpeed ZeRO-3:**

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true # Required for ZeRO-3
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
```

::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::

### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.

::: {.callout-tip}
Use GDPO when training with multiple reward functions. With a single reward, GRPO and GDPO produce equivalent results.
:::

Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)

GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85

rl: gdpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
```

You can also use GRPO with explicit aggregation control:

```yaml
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum # GDPO behavior
  # or: sum_then_normalize # Default GRPO behavior
```

#### GDPO vs GRPO

| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |

#### Why GDPO?

When using multiple rewards with GRPO, different reward combinations can produce identical advantages:

```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```

GDPO normalizes each reward independently, preserving their relative differences.
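
As a toy illustration of the mechanism (assumed semantics; the actual aggregation happens inside TRL), consider two rewards on very different scales. Summing first lets the larger-scale reward dominate the group-normalized advantage, while normalizing each reward first gives both signals comparable influence:

```python
import numpy as np

# one group of 4 completions, columns = [format (0/1), correctness (0/10)]
rewards = np.array([
    [1.0, 10.0],
    [0.0, 10.0],
    [1.0,  0.0],
    [0.0,  0.0],
])

def z(x):
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-6)

# GRPO-style: sum rewards, then normalize within the group
sum_then_normalize = z(rewards.sum(axis=1, keepdims=True)).ravel()

# GDPO-style: normalize each reward within the group, then sum
normalize_then_sum = z(rewards).sum(axis=1)

print(sum_then_normalize)  # format barely moves the advantage
print(normalize_then_sum)  # both rewards contribute equally
```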

#### Reward Functions

GDPO uses the same reward function format as GRPO:

```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(c) > 10 else 0.0 for c in completions]

def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Your scoring logic here
        score = float(answer in completion)
        rewards.append(score)
    return rewards
```

#### Sequence Parallelism

GDPO supports sequence parallelism for long-context training:

```yaml
rl: gdpo
context_parallel_size: 2
```

### SimPO

SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```

This method uses the same dataset format as [DPO](#dpo).

### EBFT {#ebft}

::: {.callout-tip}
For a detailed guide on EBFT modes, feature extraction, and configuration, see the [EBFT guide](ebft.qmd).
:::

EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments (a minimal sketch of this objective follows the list of advantages below).

Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)

**Key advantages:**

- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (code, translation, creative writing)
- Operates on model rollouts (not teacher forcing), reducing distribution shift
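
To give a feel for the objective, here is a minimal sketch (assumed shapes and helper name, not Axolotl's implementation) of how an alignment reward and diversity penalty could be computed from pooled hidden-state features; the resulting per-rollout rewards then flow through a GRPO-style policy-gradient update:

```python
import torch
import torch.nn.functional as F

def ebft_reward(gen_feats, gt_feats, alignment_coef=1.0, diversity_coef=1.0):
    """gen_feats: [num_generations, dim] pooled features of the model's rollouts.
    gt_feats: [dim] pooled features of the ground-truth completion."""
    gen = F.normalize(gen_feats, dim=-1)
    gt = F.normalize(gt_feats, dim=-1)
    alignment = gen @ gt                          # cosine similarity to ground truth
    pairwise = gen @ gen.T                        # similarity among the rollouts
    off_diag = pairwise - torch.diag(pairwise.diagonal())
    diversity_penalty = off_diag.mean(dim=-1)     # how redundant each rollout is
    return alignment_coef * alignment - diversity_coef * diversity_penalty
```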

EBFT supports two modes:

- **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).
- **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed.

#### Structured Mode

```yaml
base_model: Qwen/Qwen3-4B

rl: ebft

ebft:
  feature_layers: [0.25, 0.5, 0.75] # Extract features at 25%, 50%, 75% depth
  embed_method: last_token
  use_whitening: false
  alignment_coef: 1.0 # Cosine similarity reward weight
  diversity_coef: 1.0 # Pairwise dot product penalty
  ce_coef: 0.0 # Cross-entropy on GT tokens (0 = off)

trl:
  num_generations: 4
  max_completion_length: 256
  temperature: 0.7
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_lora_sync: true # LoRA adapter sync (recommended)
  vllm_sync_interval: 3
  use_data_producer: true
  async_prefetch: true # Set false for sync mode
  scale_rewards: true
  loss_type: grpo
  epsilon: 0.2

vllm:
  gpu_memory_utilization: 0.5
  max_model_len: 2048

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_opencode.transform
    split: train[:500]

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
```

```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

#### Strided Mode

For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU.

```yaml
base_model: meta-llama/Llama-3.2-1B

rl: ebft

ebft:
  mode: strided
  stride: 8
  context_length: 8
  generate_max_len: 8
  n_samples_per_prompt: 4
  temperature: 0.6
  feature_layers: [0.25, 0.5, 0.75]
  embed_method: last_token
  use_whitening: true
  alignment_coef: 1.0
  diversity_coef: 1.0
  rl_coef: 1.0
  ce_coef: 0.03
  advantage_estimator: rloo

datasets:
  - path: nvidia/OpenCodeInstruct
    type: ebft_strided_structured.transform
    split: train[:1%]

flash_attention: false
flex_attention: true # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true # Required for flex_attention
```

```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

::: {.callout-tip}
See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.
:::

#### EBFT Configuration Reference

| Parameter | Default | Description |
|-----------|---------|-------------|
| `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) |
| `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` |
| `ebft.use_whitening` | `false` | SVD whitening of feature dimensions |
| `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight |
| `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight |
| `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens |
| `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) |
| `ebft.stride` | — | Tokens between anchor points (strided mode) |
| `ebft.context_length` | — | Context window per block (strided mode) |
| `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) |
| `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) |
| `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) |

### NeMo Gym Integration

[NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`).

#### Single-Turn (Simplest)

For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls `/verify` directly on the resource server.

```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct

rl: grpo
chat_template: tokenizer_default

trl:
  use_vllm: false # Colocate mode (single GPU)
  num_generations: 4
  max_completion_length: 128
  temperature: 0.9
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
  - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    server_name: reasoning_gym

datasets:
  - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role
```

```bash
# Terminal 1: Start NeMo Gym resource server
cd ~/Gym && .venv/bin/ng_run \
  "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \
  "+skip_venv_if_present=true"

# Terminal 2: Train
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```

::: {.callout-note}
`nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined.
:::

#### Multi-Turn with Async GRPO (Recommended)

For environments with tool-use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done.

```yaml
base_model: Qwen/Qwen3-0.6B

rl: grpo
chat_template: tokenizer_default

adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]

trl:
  use_vllm: true
  vllm_mode: server
  vllm_server_host: localhost
  vllm_server_port: 8000
  vllm_lora_sync: true
  vllm_sync_interval: 5
  use_data_producer: true
  async_prefetch: true # 3x speedup
  num_generations: 4
  max_completion_length: 512
  temperature: 0.8
  reward_funcs:
    - axolotl.integrations.nemo_gym.rewards.reward_env

plugins:
  - axolotl.integrations.nemo_gym.NemoGymPlugin

nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
  - path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    server_name: example_single_tool_call

datasets:
  - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
    type: chat_template
    field_messages: responses_create_params.input
    message_field_content: content
    message_field_role: role

vllm:
  gpu_memory_utilization: 0.85
  max_model_len: 2048
```

Multi-turn requires three services running:

```bash
# Terminal 1: vLLM with LoRA + tool calling
VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \
  python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen3-0.6B --max-model-len 2048 \
  --gpu-memory-utilization 0.85 \
  --enable-lora --max-lora-rank 64 \
  --enable-auto-tool-choice --tool-call-parser hermes

# Terminal 2: NeMo Gym servers (resource + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
  "+config_paths=[configs/axolotl_tool_calling.yaml]" \
  "+skip_venv_if_present=true"

# Terminal 3: Training
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

::: {.callout-important}
Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for agent config format.
:::

#### NeMo Gym Prerequisites

```bash
# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
```

#### NeMo Gym Configuration Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` |

#### Reward Functions

| Function | Mode | Description |
|----------|------|-------------|
| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward |
| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` |

### Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```

### TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded; reference model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

```yaml
# load a separate reference model when training with an adapter
rl_adapter_ref_model: true
```
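
For intuition, here is a rough sketch (hypothetical helper; TRL handles this internally) of how reference log-probabilities can be obtained from the same PEFT model by temporarily disabling its adapters, rather than keeping a second frozen copy in memory:

```python
import torch

def reference_logps(peft_model, input_ids, attention_mask):
    # with adapters disabled, the PEFT model behaves like the frozen base model
    with torch.no_grad(), peft_model.disable_adapter():
        logits = peft_model(input_ids=input_ids, attention_mask=attention_mask).logits
    logps = torch.log_softmax(logits, dim=-1)
    # log-probability of each realized next token (shifted by one position)
    return torch.gather(logps[:, :-1], 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
```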
|