---
title: "GRPO Training"
description: "Group Relative Policy Optimization — a reinforcement learning method for training language models with verifiable reward functions."
order: 8
---

## Overview

Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).

Use GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.

Axolotl's GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.

## Architecture

GRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.

```
Terminal 1 (GPU 0)                     Terminal 2 (GPU 1)
┌──────────────────────┐               ┌──────────────────────────────────┐
│ vLLM Server          │               │ Trainer                          │
│                      │     HTTP      │                                  │
│ Serves base model    │◄─────────────►│ Background thread:               │
│ + LoRA adapter       │   /generate   │   Send prompts to vLLM           │
│                      │   /set_lora   │   Pad & collate completions      │
│ Punica kernels for   │               │                                  │
│ LoRA inference       │               │ Main thread:                     │
│                      │               │   Score completions (rewards)    │
└──────────────────────┘               │   Compute policy log-probs       │
                                       │   Calculate advantages           │
                                       │   PPO-clip gradient update       │
                                       │   Sync LoRA weights to vLLM      │
                                       └──────────────────────────────────┘
```

**Data flow for each training step:**

1. The background thread sends prompts to vLLM, which generates `num_generations` completions per prompt.
2. The main thread scores completions using your reward functions.
3. Advantages are computed within each prompt group (group-relative normalization).
4. Policy log-probabilities are computed by running a forward pass on the training model.
5. The PPO-clip loss is computed and gradients are applied.
6. Periodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.

With async prefetch enabled, step 1 for the *next* batch runs concurrently with steps 2-6 for the *current* batch.
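
The group-relative advantage (step 3) and the clipped update (step 5) can be sketched in a few lines. This is illustrative pseudocode, not Axolotl's trainer code; the tensor shapes and the `epsilon` value are assumptions.

```python
import torch

def grpo_step_sketch(rewards, logps_new, logps_old, epsilon=0.2):
    """Illustrative only: one prompt group of G completions.

    rewards:   (G,)   combined reward per completion
    logps_new: (G, T) per-token log-probs under the current policy
    logps_old: (G, T) per-token log-probs that produced the completions
    """
    # Step 3: group-relative advantages, normalized within the group
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)

    # Step 5: PPO-style clipped objective on the token-level probability ratio
    ratio = torch.exp(logps_new - logps_old)            # (G, T)
    adv = advantages.unsqueeze(-1)                       # broadcast over tokens
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * adv
    loss = -torch.minimum(unclipped, clipped).mean()
    return loss
```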

## Quick Start

A GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.

### 1. Write a reward module

Create a file called `rewards.py` in your working directory:

```python
# rewards.py
import re


def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    """Check if the completion contains the correct numerical answer."""
    rewards = []
    for completion, correct in zip(completions, answer):
        text = completion[0]["content"]
        # Extract the last number from the completion
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        predicted = numbers[-1] if numbers else ""
        rewards.append(1.0 if predicted == str(correct) else 0.0)
    return rewards


def format_reward(completions, **kwargs) -> list[float]:
    """Reward completions that use a structured thinking format."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        has_think = "<think>" in text and "</think>" in text
        has_answer = "<answer>" in text and "</answer>" in text
        rewards.append(1.0 if has_think and has_answer else 0.0)
    return rewards


def prompt_transform(cfg, *args, **kwargs):
    """Convert GSM8K dataset rows into chat prompts."""
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [
                {"role": "system", "content": "Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags."},
                {"role": "user", "content": example["question"]},
            ],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
```

### 2. Write the config

Create `config.yaml`:

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

rl: grpo
chat_template: tokenizer_default

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 2048

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  use_vllm: true
  use_data_producer: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
  vllm_lora_sync: true
  num_generations: 8
  max_completion_length: 512
  temperature: 0.7
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.5

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.prompt_transform
    split: train

skip_prepare_dataset: true
val_set_size: 0.0
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 200
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10

bf16: true
attn_implementation: flash_attention_2
gradient_checkpointing: true

special_tokens:
  pad_token: "<|endoftext|>"

output_dir: ./grpo-output
logging_steps: 1
```

### 3. Start vLLM and train

```bash
# Terminal 1: Start vLLM server on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Wait 30-90 seconds for model loading and CUDA graph capture

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

:::{.callout-tip}
Use `tmux` or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.
:::

## Custom Reward Functions

### Function signature

TRL calls reward functions with this signature:

```python
def my_reward(completions, **kwargs) -> list[float]:
```

- `completions` is a list of single-element lists, where each element is a dict `{"role": "assistant", "content": "..."}`. So `completions[i][0]["content"]` gives you the text of the i-th completion.
- `**kwargs` contains all dataset columns that were *not* removed by the dataset transform. This is how you pass ground truth answers, metadata, or any other information to your reward function.
- Return a `list[float]` with the same length as `completions`. You may return `None` for individual elements to exclude them from aggregation.
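
To make the calling convention concrete, here is a self-contained sketch of the payload TRL passes in, reusing `accuracy_reward` from the Quick Start module. The completion texts and the `answer` values are invented.

```python
from rewards import accuracy_reward  # the Quick Start reward module

# Hypothetical payload for two completions of the same prompt
completions = [
    [{"role": "assistant", "content": "<think>6 * 7 = 42</think><answer>42</answer>"}],
    [{"role": "assistant", "content": "I think the answer is 41."}],
]
# Extra dataset columns arrive as keyword arguments, one value per completion
extra_columns = {"answer": ["42", "42"]}

scores = accuracy_reward(completions, **extra_columns)
print(scores)  # [1.0, 0.0]
```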

### Example: accuracy reward with answer extraction

```python
import re


def accuracy_reward(completions, answer, **kwargs) -> list[float]:
    rewards = []
    for completion, correct_answer in zip(completions, answer):
        text = completion[0]["content"]
        # Extract answer from <answer>...</answer> tags
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        predicted = match.group(1).strip() if match else ""
        rewards.append(1.0 if predicted == str(correct_answer) else 0.0)
    return rewards
```

### Example: length penalty

```python
def length_penalty(completions, **kwargs) -> list[float]:
    """Penalize very short or very long completions."""
    rewards = []
    for completion in completions:
        length = len(completion[0]["content"])
        if length < 50:
            rewards.append(-0.5)
        elif length > 2000:
            rewards.append(-0.2)
        else:
            rewards.append(0.0)
    return rewards
```

### Multiple rewards and weighting

You can combine multiple reward functions with different weights:

```yaml
trl:
  reward_funcs:
    - rewards.accuracy_reward
    - rewards.format_reward
    - rewards.length_penalty
  reward_weights:
    - 1.0   # accuracy is most important
    - 0.5   # format compliance
    - 0.1   # mild length preference
```

Rewards are combined by the `multi_objective_aggregation` strategy:

- `sum_then_normalize` (default): weights and sums all rewards first, then normalizes across the group.
- `normalize_then_sum` (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.

```yaml
trl:
  multi_objective_aggregation: normalize_then_sum
```
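
As a rough numerical illustration of the difference between the two strategies (toy reward values, not the trainer's exact math):

```python
import numpy as np

# Two reward functions on very different scales, four completions in one group
accuracy = np.array([1.0, 0.0, 1.0, 0.0])
length   = np.array([120.0, 80.0, 40.0, 200.0])
weights  = [1.0, 0.1]

def normalize(x):
    return (x - x.mean()) / (x.std() + 1e-4)

# sum_then_normalize: weight and sum first, then normalize the combined reward
sum_then_normalize = normalize(weights[0] * accuracy + weights[1] * length)

# normalize_then_sum (GDPO): normalize each reward first, so neither scale dominates
normalize_then_sum = weights[0] * normalize(accuracy) + weights[1] * normalize(length)

print(sum_then_normalize)   # ordering is driven almost entirely by the length values
print(normalize_then_sum)   # accuracy and length contribute on comparable scales
```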

### Dataset transforms

The dataset transform converts raw HuggingFace dataset rows into chat-format prompts:

```python
def prompt_transform(cfg, *args, **kwargs):
    def map_fn(example, tokenizer=None):
        return {
            "prompt": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": example["question"]},
            ],
            # Keep 'answer' column for the reward function
            "answer": example["answer"],
        }
    # Remove columns consumed by the transform; keep columns needed by rewards
    return map_fn, {"remove_columns": ["question"]}
```

The transform returns a tuple of `(map_function, kwargs_dict)`. The `remove_columns` entry in the kwargs dict drops columns that are no longer needed. Columns that your reward functions reference via `**kwargs` (like `answer`) must *not* be removed.
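
For reference, a single transformed row ends up with roughly this shape (the question and answer values are made up):

```python
# Illustrative only: the shape of one transformed row as the trainer sees it
{
    "prompt": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 6 times 7?"},  # example["question"]
    ],
    "answer": "42",  # kept so the reward function receives it via **kwargs
}
```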

:::{.callout-warning}
The reward module must be importable from the directory where you run `axolotl train`. If your reward file is `rewards.py`, the import path is `rewards.accuracy_reward`. If it is inside a package `my_rewards/scoring.py`, use `my_rewards.scoring.accuracy_reward`.
:::

### Reward models (neural network rewards)

Instead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:

```yaml
trl:
  reward_funcs:
    - OpenAssistant/reward-model-deberta-v3-large-v2
    - rewards.format_reward
  reward_weights:
    - 1.0
    - 0.3
```

### Using math_verify

The `math_verify` library provides robust mathematical answer verification but uses `signal.alarm()` internally, which only works in the main thread. If you use `math_verify` in a reward function, set `reward_num_workers` to use subprocess workers:

```yaml
trl:
  reward_num_workers: 4
```

Each worker runs in its own subprocess with its own main thread, so `signal.alarm()` works correctly.
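
A reward function built on `math_verify` might look like the sketch below. It assumes the `parse`/`verify` API of the `math-verify` package; adapt it to the version you have installed, and keep `reward_num_workers` set as described above so `signal.alarm()` runs in a worker's main thread.

```python
from math_verify import parse, verify


def math_accuracy_reward(completions, answer, **kwargs) -> list[float]:
    """Score 1.0 when the completion's answer is mathematically equivalent
    to the gold answer, 0.0 otherwise."""
    rewards = []
    for completion, gold in zip(completions, answer):
        text = completion[0]["content"]
        try:
            # parse() extracts a mathematical expression; verify() checks equivalence
            rewards.append(1.0 if verify(parse(str(gold)), parse(text)) else 0.0)
        except Exception:
            # Malformed output: score zero rather than crashing the scoring step
            rewards.append(0.0)
    return rewards
```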

## vLLM Setup

GRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see [vLLM Serving](vllm_serving.qmd).

The minimal setup:

```yaml
vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85

trl:
  use_vllm: true
  vllm_lora_sync: true     # Recommended with LoRA — faster sync, no NCCL contention
  vllm_sync_interval: 5    # Sync weights every 5 steps
```

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml   # GPU 0: vLLM
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml        # GPU 1: training
```

:::{.callout-warning}
vLLM must be restarted between experiments — stale weight syncs corrupt server state. See [Restart Requirements](vllm_serving.qmd#sec-restart).
:::

## Async Training Features

Async GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.

### Enabling async prefetch

```yaml
trl:
  use_data_producer: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
```

- `use_data_producer: true` enables the data producer protocol (required for all async features).
- `async_prefetch: true` runs generation in a background thread.
- `prefetch_depth` controls how many batches to prefetch ahead (1 is usually sufficient).
- `vllm_sync_interval` controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.

:::{.callout-tip}
Because the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable `vllm_importance_sampling_correction: true` when using `async_prefetch: true`.
:::
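
Conceptually, the prefetch overlap is a bounded producer/consumer hand-off between a generation thread and the training loop. The sketch below is a simplified stand-in for the real pipeline; `generate_batch` and `train_step` are placeholder callables.

```python
import queue
import threading

def run_prefetch_loop(generate_batch, train_step, num_steps, prefetch_depth=1):
    """Overlap generation (producer) with training (consumer)."""
    batches = queue.Queue(maxsize=prefetch_depth)  # queue depth bounded by prefetch_depth

    def producer():
        for step in range(num_steps):
            # Generates with whatever weights vLLM currently has (slightly stale)
            batches.put(generate_batch(step))

    threading.Thread(target=producer, daemon=True).start()

    for step in range(num_steps):
        batch = batches.get()   # the next batch is usually already waiting
        train_step(batch)       # score, compute advantages, update the policy
```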

### Streaming partial batch

Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.

```yaml
trl:
  streaming_partial_batch: true
  streaming_min_groups: 1
```

`streaming_min_groups` controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.

### Zero-advantage batch skipping

When all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.

```yaml
trl:
  skip_zero_advantage_batches: true   # default
```

This is enabled by default and logged as `skipped_zero_adv_batches` in training metrics. It is a safety net, not a major optimization -- it only saves significant time when the model cannot solve any prompts in the batch.
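
The skip condition follows directly from group-relative normalization: if every completion in a group gets the same reward, the centered rewards are all zero and the update carries no gradient. A minimal illustration with toy values:

```python
import numpy as np

rewards = np.array([0.0, 0.0, 0.0, 0.0])   # all completions in the group failed identically
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)                           # [0. 0. 0. 0.] -> no learning signal, batch skipped
```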

### Replay buffer

The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.

```yaml
trl:
  replay_buffer_size: 100
  replay_recompute_logps: true
```

:::{.callout-warning}
When `replay_recompute_logps: false`, replayed data uses stale log-probabilities, which creates an IS mismatch. Keep the default `true` unless you have a specific reason to disable it.
:::

### Deferred re-rolling

Prompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.

```yaml
trl:
  reroll_start_fraction: 0.5   # Start re-rolling after 50% of training
  reroll_max_groups: 1         # Max groups to replace per batch
```

Set `reroll_start_fraction: 1.0` to disable. This is most useful for tasks where the model starts weak but steadily improves.

### Parallel reward workers

Reward functions that use `signal.alarm()` (like `math_verify`) only work in the main thread. Parallel reward workers run each function in its own subprocess:

```yaml
trl:
  reward_num_workers: 4
```

Work is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient -- the overhead of IPC can exceed the computation time.

## Importance Sampling and Off-Policy Correction

When using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.

```yaml
trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token   # 'token' recommended (especially with Liger kernel)
  off_policy_mask_threshold: 0.5     # KL threshold — masks sequences that are too off-policy
```

Use `token`-level IS; sequence-level has numerical issues with Liger's chunked computation. The `off_policy_mask_threshold` (OPSM) is a safety net that drops sequences where KL divergence exceeds the threshold — 0.5 is a reasonable starting point.
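
In outline, the correction reweights each token's loss by the ratio between the trainer's log-probabilities and the ones recorded at generation time. The sketch below is schematic, not Axolotl's implementation; the tensor shapes and cap value are assumptions.

```python
import torch

def token_is_weights(logps_trainer, logps_vllm, cap=2.0):
    """Per-token importance weights for completions generated by a stale policy.

    logps_trainer: (B, T) log-probs under the current training policy
    logps_vllm:    (B, T) log-probs recorded by vLLM when it generated the tokens
    """
    ratio = torch.exp(logps_trainer - logps_vllm)   # pi_train / pi_vllm, per token
    # 'token_truncate'-style behavior: clamp large ratios instead of masking them
    return ratio.clamp(max=cap).detach()            # multiply into the per-token loss
```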

For detailed coverage of IS modes (`token_mask`, `token_truncate`, etc.), capping, and bias-corrected KL, see [vLLM Serving — IS Correction](vllm_serving.qmd#sec-weight-sync).

## Scaling

### FP8 training

FP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models but allows larger models to fit in memory.

```yaml
fp8: true
torch_compile: true
```

:::{.callout-warning}
FP8 requires patching for zero-padding edge cases. The `act_quant_kernel` can produce NaN when input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.
:::
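
A quick way to run the check suggested above is to load the tokenizer and model and inspect the pad token's embedding row. This is a plain transformers snippet, not an Axolotl command; the model id is a placeholder for your `base_model`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-1.5B-Instruct"   # placeholder: use your base_model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

pad_id = tokenizer.pad_token_id
assert pad_id is not None, "set special_tokens.pad_token in your config first"

row = model.get_input_embeddings().weight[pad_id]
print(f"pad token id {pad_id}, all-zero embedding: {bool((row == 0).all())}")
```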

### FSDP (Fully Sharded Data Parallel)

FSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

Launch with:

```bash
# GPU 0: vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# GPUs 1,2: Training (FSDP will use both visible GPUs)
CUDA_VISIBLE_DEVICES=1,2 axolotl train config.yaml
```

:::{.callout-warning}
`async_prefetch: true` can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set `async_prefetch: false`.
:::

### DeepSpeed ZeRO-3

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true   # Required -- non-reentrant causes CheckpointError with ZeRO-3
```

:::{.callout-note}
DeepSpeed ZeRO-3 requires `use_reentrant: true` for gradient checkpointing. This is the opposite of the FSDP recommendation. Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3's parameter partitioning.
:::

### Multi-GPU considerations

| Concern | Recommendation |
|---------|----------------|
| vLLM GPU allocation | Dedicate one or more GPUs to vLLM; do not share with trainer GPUs |
| Weight sync contention | Use `vllm_lora_sync: true` to avoid NCCL contention between training and vLLM |
| FSDP + async | Use `async_prefetch: false` or rely on rank-0-only background generation |
| DeepSpeed + gradient checkpoint | Must use `use_reentrant: true` |
| OOM during scoring | Reduce `micro_batch_size` or `num_generations`. The logits tensor scales with `batch_size * vocab_size` |

## Monitoring and Debugging

For detailed metric ranges, failure diagnosis, and OOM debugging, see [Training Stability & Debugging](training_stability.qmd).

Quick health checks during GRPO training:

- `rewards/*/mean` should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone
- `reward_std` should be > 0 on most steps — all-zero means no learning signal
- `entropy` in 0.05-0.5 — below 0.01 suggests mode collapse
- `grad_norm` in 0.001-1.0 — > 10 is unstable, 0.0 is expected when zero-advantage skip fires

:::{.callout-tip}
Pipe training output to a log file: `axolotl train config.yaml 2>&1 | tee /tmp/training.log`
:::

## Configuration Reference

All GRPO-specific options live under the `trl:` key in your config. Standard training options (`learning_rate`, `micro_batch_size`, etc.) are set at the top level as usual.

### Core GRPO

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | `"server"` or `"colocate"` | `null` | vLLM deployment mode |
| `vllm_server_host` | str | `"0.0.0.0"` | vLLM server hostname |
| `vllm_server_port` | int | `8000` | vLLM server port |
| `vllm_server_timeout` | int | `null` | Timeout (seconds) for vLLM responses |
| `num_generations` | int | `null` | Completions generated per prompt |
| `generation_batch_size` | int | `null` | Number of unique prompts per generation step |
| `max_completion_length` | int | `null` | Maximum tokens per completion |
| `beta` | float | `null` | KL penalty coefficient |
| `num_iterations` | int | `null` | Iterations per batch (mu in the GRPO paper) |
| `epsilon` | float | `null` | PPO clipping lower bound |
| `epsilon_high` | float | `null` | PPO clipping upper bound |
| `loss_type` | str | `null` | Loss formulation: `grpo`, `bnpo`, or `dr_grpo` |
| `scale_rewards` | bool | `true` | Normalize rewards by standard deviation |
| `mask_truncated_completions` | bool | `false` | Exclude truncated completions from loss |

### Reward functions

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `reward_funcs` | list[str] | `null` | Import paths to reward functions or HF model IDs |
| `reward_weights` | list[float] | `null` | Relative weights for each reward function |
| `multi_objective_aggregation` | str | `null` | `"sum_then_normalize"` (GRPO) or `"normalize_then_sum"` (GDPO) |
| `rollout_func` | str | `null` | Import path to custom rollout function for OpenEnv-style tasks |

### Generation parameters

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `temperature` | float | `null` | Sampling temperature |
| `top_p` | float | `null` | Nucleus sampling probability |
| `top_k` | int | `null` | Top-k sampling |
| `min_p` | float | `null` | Minimum probability threshold |
| `repetition_penalty` | float | `null` | Penalty for repeated tokens |
| `generation_kwargs` | dict | `null` | Additional vLLM SamplingParams (e.g., `stop_token_ids`) |
| `chat_template_kwargs` | dict | `null` | Chat template kwargs (e.g., `{enable_thinking: false}`) |
| `vllm_guided_decoding_regex` | str | `null` | Regex constraint for guided decoding |

### Async pipeline

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_data_producer` | bool | `false` | Enable data producer protocol (required for async features) |
| `async_prefetch` | bool | `false` | Generate next batch in background thread |
| `prefetch_depth` | int | `null` | Number of batches to prefetch ahead |
| `vllm_sync_interval` | int | `null` | Sync LoRA weights to vLLM every N steps |
| `vllm_lora_sync` | bool | `false` | Use filesystem LoRA sync instead of NCCL merge |
| `streaming_partial_batch` | bool | `null` | Score prompt groups incrementally |
| `streaming_min_groups` | int | `null` | Minimum groups per streaming chunk |
| `skip_zero_advantage_batches` | bool | `true` | Skip micro-batches with zero learning signal |
| `reward_num_workers` | int | `1` | Subprocess workers for reward computation |
| `vllm_enable_sleep_mode` | bool | `null` | Offload vLLM weights when idle (colocate mode) |

### Importance sampling

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `vllm_importance_sampling_correction` | bool | `null` | Enable IS correction for async distribution shift |
| `importance_sampling_level` | `"token"` or `"sequence"` | `null` | Granularity of IS ratios. Use `token` with Liger |
| `vllm_importance_sampling_mode` | str | `null` | `token_mask`, `token_truncate`, `sequence_mask`, or `sequence_truncate` |
| `vllm_importance_sampling_cap` | float | `null` | Cap C for IS ratio clipping/masking |
| `off_policy_mask_threshold` | float | `null` | KL threshold for off-policy sequence masking (OPSM) |
| `use_bias_correction_kl` | bool | `null` | Apply IS correction to KL divergence term |

### Replay and re-roll

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `replay_buffer_size` | int | `0` | Max cached high-signal groups. 0 = disabled |
| `replay_recompute_logps` | bool | `true` | Recompute log-probs for replayed data with current model |
| `reroll_start_fraction` | float | `1.0` | Start re-rolling failed prompts after this fraction of training. 1.0 = disabled |
| `reroll_max_groups` | int | `1` | Max prompt groups to replace with re-rolls per batch |

### Reference model

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `sync_ref_model` | bool | `false` | Periodically sync reference model with training model |
| `ref_model_mixup_alpha` | float | `0.9` | EMA coefficient for reference model sync |
| `ref_model_sync_steps` | int | `64` | Sync reference model every N steps |

### Logging

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `log_completions` | bool | `false` | Log sample completions to W&B |
| `num_completions_to_print` | int | `null` | Number of completions to print per step |
| `use_liger_loss` | bool | `null` | Use Liger fused kernel for GRPO loss (reduces VRAM) |