---
title: "vLLM Serving for GRPO Training"
description: "How to configure and run vLLM as a generation backend for GRPO reinforcement learning in Axolotl."
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
execute:
  enabled: false
---

## Overview {#sec-overview}

GRPO (Group Relative Policy Optimization) trains a language model by generating completions, scoring them with reward functions, and updating the policy to favor higher-reward outputs. The generation step is the bottleneck: producing thousands of tokens per training step with the policy model is slow using standard HuggingFace generation.

Axolotl uses [vLLM](https://github.com/vllm-project/vllm) as a high-throughput generation backend. In server mode, vLLM runs as a separate process on dedicated GPU(s) and serves completions over an HTTP API. In colocate mode, it shares the training GPU instead. The trainer sends prompts to vLLM, receives completions, scores them, and performs gradient updates. The following diagram shows server mode:

```
┌──────────────────────┐        HTTP         ┌──────────────────────┐
│  Trainer (GPU 1)     │ ──────────────────► │  vLLM Server (GPU 0) │
│                      │ prompts/completions │                      │
│  - Policy model      │ ◄────────────────── │  - Same base model   │
│  - Reward scoring    │                     │  - Fast generation   │
│  - Gradient updates  │     weight sync     │  - LoRA adapter      │
│  - LoRA adapter      │ ──────────────────► │    (periodically     │
│                      │   (every N steps)   │     updated)         │
└──────────────────────┘                     └──────────────────────┘
```

::: {.callout-important}
vLLM must serve the **same base model** specified in your training config. If the models do not match, weight synchronization will silently produce incorrect results.
:::

## Server Mode {#sec-server-mode}

Server mode runs vLLM as an external process on dedicated GPU(s). This is the recommended configuration for most setups.

### Starting the Server

Use the `axolotl vllm-serve` command with your training config:

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```

```bash
# Terminal 2: Start training on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```

The server reads vLLM settings from the `vllm:` section of your config and starts an HTTP server (default: `http://0.0.0.0:8000`).

::: {.callout-tip}
Use `tmux` or `screen` to manage the vLLM server process. Typical startup time is 30-90 seconds depending on model size and whether CUDA graphs are captured.
:::

### Minimal Server Config

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 4096

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

### Multi-GPU vLLM

For larger models, use tensor parallelism across multiple GPUs:

```yaml
vllm:
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
```

```bash
# vLLM on GPUs 2,3; training on GPUs 0,1
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo_config.yaml
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo_config.yaml --num-processes 2
```

::: {.callout-note}
Due to how TRL maps vLLM device indices, the vLLM instance should use the **last** N GPUs (highest device indices), while training uses the first N.
:::
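
The device-ordering rule in the note above can be captured in a small helper. The functions below are hypothetical (not part of Axolotl or TRL) and assume GPUs are numbered `0` to `TOTAL-1`. `vllm_gpus` returns the highest-indexed GPUs for the vLLM server, and `train_gpus` returns the remaining lower indices for the trainer.

```bash
# Hypothetical helpers (not part of Axolotl): build CUDA_VISIBLE_DEVICES lists
# that put vLLM on the highest-indexed GPUs and training on the rest.

# join_range START END -> "START,START+1,...,END"
join_range() {
  local out="" i
  for (( i = $1; i <= $2; i++ )); do
    out+="${out:+,}$i"   # comma before every index except the first
  done
  printf '%s\n' "$out"
}

# check_split TOTAL VLLM: both sides must get at least one GPU
check_split() {
  if (( $2 < 1 || $2 >= $1 )); then
    echo "need 1 <= VLLM_GPUS < TOTAL_GPUS (got TOTAL=$1, VLLM=$2)" >&2
    return 1
  fi
}

# vllm_gpus TOTAL VLLM -> the last VLLM indices
vllm_gpus() { check_split "$1" "$2" && join_range $(( $1 - $2 )) $(( $1 - 1 )); }

# train_gpus TOTAL VLLM -> the remaining first indices
train_gpus() { check_split "$1" "$2" && join_range 0 $(( $1 - $2 - 1 )); }
```

For the four-GPU example above, `CUDA_VISIBLE_DEVICES=$(vllm_gpus 4 2)` expands to `2,3` and `CUDA_VISIBLE_DEVICES=$(train_gpus 4 2)` expands to `0,1`.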

## Colocate Mode {#sec-colocate-mode}

Colocate mode runs vLLM on the same GPU as the trainer. This is useful when you only have a single GPU.

```yaml
trl:
  use_vllm: true
  vllm_mode: colocate
  vllm_enable_sleep_mode: true
```

With `vllm_enable_sleep_mode: true`, vLLM offloads its VRAM allocation when not actively generating, freeing memory for training. When the trainer needs new completions, vLLM wakes up and reclaims VRAM.

::: {.callout-warning}
Colocate mode is significantly slower than server mode because generation and training cannot overlap. The GPU alternates between the two workloads. This mode is practical only for smaller models (up to ~3B on a 24 GB GPU).
:::

**When to use colocate mode:**

- You have exactly one GPU
- The model fits in memory with both vLLM and training active (with sleep mode), or is small enough to time-share
- You accept the performance tradeoff for simpler setup (no separate vLLM process to manage)

**When to use server mode:**

- You have two or more GPUs
- You want maximum throughput (generation overlaps with training via async prefetch)
- You are running larger models (7B+)

## LoRA Sync {#sec-lora-sync}

LoRA sync is the recommended weight synchronization method when training with LoRA adapters. Instead of merging adapter weights into the base model and broadcasting the full merged weights over NCCL, it saves only the LoRA adapter files to the filesystem and tells vLLM to load them natively.

### How It Works

1. The trainer calls `model.save_pretrained()` to write the LoRA adapter weights to a temporary directory
2. The trainer sends an HTTP POST to `/set_lora_adapter/` on the vLLM server
3. vLLM loads the adapter using its native LoRA support (Punica kernels)
4. Generation uses the updated adapter on the next request

### Benefits

- **Smaller sync payload**: Transfers ~40 MB of LoRA weights instead of ~1.4 GB+ of merged model weights (for a typical 0.5-3B model)
- **No NCCL communicator**: Eliminates the need for a cross-GPU NCCL communication channel, removing GPU contention between vLLM generation and weight sync
- **Faster sync**: ~200 ms per sync vs. 350 ms to 5+ seconds for NCCL merge sync
- **Simpler multi-GPU**: No need to set up NCCL groups between trainer and vLLM processes
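
The payload gap follows from how LoRA is parameterized. A rank-`r` adapter on a linear layer with input size `d_in` and output size `d_out` stores two small matrices, A (`r × d_in`) and B (`d_out × r`). That is `r * (d_in + d_out)` parameters instead of the layer's full `d_in * d_out`. The hypothetical `lora_adapter_bytes` function below (not an Axolotl command) sums this over layer shapes you supply. It ignores file metadata, so treat the result as an estimate.

```bash
# Hypothetical estimator (not part of Axolotl): approximate LoRA adapter size.
# Usage: lora_adapter_bytes RANK BYTES_PER_PARAM D_IN:D_OUT [D_IN:D_OUT ...]
# Pass one D_IN:D_OUT pair per adapted linear layer (repeat it for every
# transformer layer that contains that projection).
lora_adapter_bytes() {
  local rank=$1 bytes_per_param=$2 params=0 shape d_in d_out
  shift 2
  for shape in "$@"; do
    d_in=${shape%%:*}    # text before the colon
    d_out=${shape##*:}   # text after the colon
    # LoRA A is rank x d_in and LoRA B is d_out x rank
    params=$(( params + rank * (d_in + d_out) ))
  done
  echo $(( params * bytes_per_param ))
}
```

Take a single 4096 × 4096 projection at `r = 8`, stored in bf16 (2 bytes per parameter). For that layer, `lora_adapter_bytes 8 2 4096:4096` prints `131072` (128 KiB), whereas the dense weight itself occupies 33554432 bytes (32 MiB).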

### Configuration

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true      # Enables LoRA sync mode
  vllm_sync_interval: 5     # Sync every 5 training steps
```

Setting `vllm_lora_sync: true` automatically selects the LoRA-aware vLLM serve script (`axolotl.scripts.vllm_serve_lora`). You do not need to set `vllm.serve_module` manually.

::: {.callout-important}
LoRA sync requires that you are training with a LoRA adapter (`adapter: lora` or `adapter: qlora`). It is not applicable to full fine-tuning.
:::

## Weight Synchronization {#sec-weight-sync}

During GRPO training, the policy model on the trainer is continuously updated via gradient steps. The vLLM server, however, still holds the old weights. Periodically, the trainer must push updated weights to vLLM so that future generations reflect the improved policy.

### Sync Interval

The `vllm_sync_interval` parameter controls how often weights are synced:

```yaml
trl:
  vllm_sync_interval: 5     # Sync every 5 optimizer steps
```

**Tradeoffs:**

- **Lower interval** (e.g., 1-3): Fresher generations, better on-policy data, but more sync overhead per step
- **Higher interval** (e.g., 5-10): Less overhead, but generations become increasingly off-policy between syncs
- **Recommended**: 3-5 for most setups. Axolotl includes importance sampling correction (`vllm_importance_sampling_correction: true`) to handle mild distribution mismatch from stale vLLM weights.
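
The hypothetical `sync_stats` helper below (not part of Axolotl) puts numbers on this tradeoff. It assumes weights are pushed to vLLM after every optimizer step that is a multiple of `vllm_sync_interval`, and it ignores any extra lag from async prefetch. Under those assumptions, a run of `S` steps performs `floor(S / interval)` syncs, and the vLLM copy is never more than `interval - 1` updates behind the trainer.

```bash
# Hypothetical helper (not part of Axolotl). Assumes a weight sync after every
# optimizer step that is a multiple of INTERVAL; ignores async-prefetch lag.
# Usage: sync_stats TOTAL_STEPS INTERVAL
sync_stats() {
  local steps=$1 interval=$2 max_lag
  if (( interval < 1 )); then
    echo "INTERVAL must be >= 1" >&2
    return 1
  fi
  max_lag=$(( interval - 1 ))
  # A run shorter than one interval never syncs, so lag is bounded by its length
  if (( steps < max_lag )); then
    max_lag=$steps
  fi
  echo "syncs=$(( steps / interval ))"
  echo "max_lag=$max_lag"
}
```

For example, `sync_stats 500 5` reports 100 syncs with at most 4 steps of staleness. `sync_stats 500 1` reports 500 syncs and no staleness.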

### Sync Methods

| Method | Config | Payload | Mechanism | Typical Time |
|--------|--------|---------|-----------|-------------|
| **LoRA sync** | `vllm_lora_sync: true` | LoRA adapter only (~40 MB) | Filesystem + HTTP | ~200 ms |
| **NCCL merge sync** | Default (`vllm_lora_sync: false`) | Full merged weights (~1.4 GB+) | HTTP trigger + NCCL broadcast | 350 ms - 5 s |

::: {.callout-tip}
If you are training with LoRA (which is recommended for GRPO), always enable `vllm_lora_sync: true`. The performance difference is substantial, especially as training progresses and NCCL contention increases.
:::

### Importance Sampling Correction

When vLLM weights are stale (between syncs), the generated data is slightly off-policy. Axolotl can correct for this:

```yaml
trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token   # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5     # KL threshold for masking stale sequences
```

- **Token-level IS** is recommended when using Liger kernel (sequence-level has numerical issues with chunked computation)
- **Off-policy sequence masking (OPSM)** drops sequences that have diverged too far from the current policy, providing a safety net against stale data
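
The following is a generic sketch of the quantities involved, not a description of Axolotl's or TRL's exact implementation, which may also truncate or normalize the weights. Let $\pi_\theta$ be the trainer's current policy and $\pi_{\text{vLLM}}$ the stale copy that generated completion $y = (y_1, \dots, y_T)$ for prompt $x$. Token-level correction reweights each token's loss term by its own ratio $w_t$. Sequence-level correction uses one weight per completion, and the length-normalized form shown here is one common choice. The last line is a simplified off-policy mask: it compares a sample estimate of the KL divergence between the two policies against $\delta$ (`off_policy_mask_threshold`). Some implementations additionally apply the mask only to sequences with negative advantage.

$$
\begin{aligned}
w_t &= \exp\big(\log \pi_\theta(y_t \mid x, y_{<t}) - \log \pi_{\text{vLLM}}(y_t \mid x, y_{<t})\big) && \text{(token level)} \\
w_{\text{seq}} &= \exp\Big(\frac{1}{T} \sum_{t=1}^{T} \log w_t\Big) && \text{(sequence level)} \\
\widehat{\mathrm{KL}} &= \frac{1}{T} \sum_{t=1}^{T} \big(\log \pi_{\text{vLLM}}(y_t \mid x, y_{<t}) - \log \pi_\theta(y_t \mid x, y_{<t})\big) = -\frac{1}{T} \sum_{t=1}^{T} \log w_t, \quad \text{mask } y \text{ if } \widehat{\mathrm{KL}} > \delta
\end{aligned}
$$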

## Restart Requirements {#sec-restart}

::: {.callout-warning}
**vLLM must be restarted between training runs.** Weight syncs from a previous run leave the server in a corrupted state. If you start a new training run against a stale vLLM server, the model may fail to learn.
:::

### When to Restart

- Before every new training experiment
- After a training run crashes or is interrupted
- If you change the base model in your config

### How to Restart

Killing vLLM reliably requires terminating both the main process and its background EngineCore subprocess:

```bash
# Kill all vLLM-related processes
pkill -9 -f "vllm|EngineCore"

# Verify GPU memory is freed
nvidia-smi

# Restart the server
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```

::: {.callout-tip}
A plain `kill` (SIGTERM) often does not fully stop vLLM. Use `pkill -9` as shown above, and check with `nvidia-smi` that GPU memory has been released before restarting.
:::
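
You can script the `nvidia-smi` check by querying per-GPU memory usage in CSV form and comparing it against a threshold. The `gpu_memory_released` function below is a hypothetical helper, not part of Axolotl. The 1024 MiB threshold in the usage line is an arbitrary assumption, since an idle GPU still reports some baseline usage.

```bash
# Hypothetical helper (not part of Axolotl): succeed only if every GPU listed
# on stdin reports memory.used (in MiB) below THRESHOLD_MIB.
# Usage:
#   nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits \
#     | gpu_memory_released 1024
gpu_memory_released() {
  local threshold=$1 used
  # "|| [[ -n $used ]]" also handles a final line without a trailing newline
  while read -r used || [[ -n $used ]]; do
    [[ -z $used ]] && continue            # skip blank lines
    if (( used >= threshold )); then
      echo "GPU still holds ${used} MiB" >&2
      return 1
    fi
  done
  return 0
}
```

To check only the vLLM GPU, add `-i 0` (or the relevant index) to the `nvidia-smi` call.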

### Health Check

The vLLM server exposes a health endpoint. Wait for it to return 200 before starting training:

```bash
# For the LoRA serve script (trailing slash required)
curl http://localhost:8000/health/

# For the default TRL serve script
curl http://localhost:8000/health
```
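
Instead of re-running `curl` by hand, you can poll until the server answers. The `wait_for_vllm` function below is a hypothetical helper, not part of Axolotl. It relies on `curl -sf`, which exits non-zero on connection errors and on HTTP status codes of 400 or above, and it gives up after a timeout you choose.

```bash
# Hypothetical helper (not part of Axolotl): poll a health URL until it
# responds successfully or TIMEOUT_SECONDS have elapsed.
# Usage: wait_for_vllm URL [TIMEOUT_SECONDS]
wait_for_vllm() {
  local url=$1 timeout=${2:-300} start=$SECONDS
  # -s: silent, -f: non-zero exit on HTTP >= 400, -o /dev/null: discard body
  until curl -sf -o /dev/null --max-time 5 "$url"; do
    if (( SECONDS - start >= timeout )); then
      echo "vLLM not healthy at $url after ${timeout}s" >&2
      return 1
    fi
    sleep 2
  done
  echo "vLLM is healthy at $url"
}
```

For example, `wait_for_vllm http://localhost:8000/health/ 300 && CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml` starts training only after the LoRA serve script reports healthy. Keep the trailing slash for that script.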

## Configuration Reference {#sec-config-reference}

### vLLM Server Options (`vllm:` section)

These control the vLLM server process started by `axolotl vllm-serve`.

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `host` | str | `0.0.0.0` | Host address for the vLLM server |
| `port` | int | `8000` | Port for the vLLM server |
| `device` | str | `auto` | Device to use for vLLM |
| `tensor_parallel_size` | int | `None` | Number of GPUs for tensor parallelism |
| `data_parallel_size` | int | `None` | Number of data parallel replicas |
| `gpu_memory_utilization` | float | `0.9` | Fraction of GPU memory for vLLM (0.0-1.0) |
| `dtype` | str | `auto` | Data type (`auto`, `float16`, `bfloat16`) |
| `max_model_len` | int | `None` | Maximum model context length. Set explicitly if the default is too large for your GPU |
| `enable_prefix_caching` | bool | `None` | Enable prefix caching for repeated prompt prefixes |
| `enable_reasoning` | bool | `None` | Enable reasoning mode for models with thinking tokens |
| `reasoning_parser` | str | `None` | Parser for reasoning output |
| `enforce_eager` | bool | `None` | Disable CUDA graph capture (required for some architectures like Qwen3.5 hybrid attention) |
| `serve_module` | str | `None` | Python module for vLLM serve script. Auto-set when `vllm_lora_sync: true` |
| `worker_extension_cls` | str | `None` | vLLM worker extension class for weight sync |

### Trainer vLLM Options (`trl:` section)

These control how the trainer interacts with vLLM.

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | str | `None` | `server` (external process) or `colocate` (same GPU) |
| `vllm_server_host` | str | `0.0.0.0` | Host of the vLLM server to connect to |
| `vllm_server_port` | int | `8000` | Port of the vLLM server to connect to |
| `vllm_server_timeout` | int | `None` | Timeout in seconds for vLLM requests |
| `vllm_lora_sync` | bool | `false` | Sync LoRA adapters via filesystem instead of NCCL merge |
| `vllm_sync_interval` | int | `None` | Sync weights every N optimizer steps |
| `vllm_enable_sleep_mode` | bool | `None` | Offload vLLM VRAM when idle (colocate mode) |
| `vllm_guided_decoding_regex` | str | `None` | Regex constraint for guided decoding |

For async pipeline and off-policy correction options, see the [GRPO Configuration Reference](grpo.qmd#configuration-reference).

## Complete Example {#sec-complete-example}

For a full working GRPO config including vLLM, LoRA sync, async generation, rewards, and dataset setup, see the [GRPO Quick Start](grpo.qmd#quick-start). That config includes all the vLLM settings covered in this guide.

```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml

# Wait for health check to pass
curl http://localhost:8000/health/

# Terminal 2: Start training
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```

## Troubleshooting {#sec-troubleshooting}

| Problem | Likely Cause | Solution |
|---------|-------------|----------|
| Training hangs waiting for vLLM | Server not started or wrong port | Check `curl http://localhost:8000/health/` and verify `vllm_server_host`/`vllm_server_port` match |
| OOM on vLLM GPU | `gpu_memory_utilization` too high or `max_model_len` too large | Reduce `gpu_memory_utilization` to 0.7 or set `max_model_len` explicitly |
| OOM on training GPU | Batch too large for policy logprobs | Reduce `micro_batch_size` or `num_generations` |
| Accuracy stays at zero | Stale vLLM from previous run | Restart vLLM: `pkill -9 -f "vllm\|EngineCore"`, verify with `nvidia-smi`, restart |
| `ResponseValidationError` from vLLM | Missing logprobs in response | Ensure you are using the correct serve module (auto-selected with `vllm_lora_sync: true`) |
| Weight sync takes 5+ seconds | NCCL contention with vLLM generation | Switch to `vllm_lora_sync: true` to eliminate NCCL |
| `async_prefetch` deadlocks with FSDP | Background threads run unsynchronized FSDP collectives | Set `async_prefetch: false` when using FSDP or DeepSpeed multi-GPU |