---
title: "vLLM Serving for GRPO Training"
description: "How to configure and run vLLM as a generation backend for GRPO reinforcement learning in Axolotl."
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
execute:
  enabled: false
---

## Overview {#sec-overview}

GRPO (Group Relative Policy Optimization) trains a language model by generating completions, scoring them with reward functions, and updating the policy to favor higher-reward outputs. The generation step is the bottleneck: producing thousands of tokens per training step with the policy model is slow using standard HuggingFace generation.

Axolotl uses [vLLM](https://github.com/vllm-project/vllm) as a high-throughput generation backend. vLLM runs as a separate process (either on a dedicated GPU or colocated on the training GPU) and serves completions via an HTTP API. The trainer sends prompts to vLLM, receives completions, scores them, and performs gradient updates.

```
┌──────────────────────┐        HTTP         ┌──────────────────────┐
│  Trainer (GPU 1)     │  ─────────────────  │  vLLM Server (GPU 0) │
│                      │   prompts/compls    │                      │
│  - Policy model      │  ◄────────────────  │  - Same base model   │
│  - Reward scoring    │                     │  - Fast generation   │
│  - Gradient updates  │    weight sync      │  - LoRA adapter      │
│  - LoRA adapter      │  ─────────────────► │    (periodically     │
│                      │   (every N steps)   │     updated)         │
└──────────────────────┘                     └──────────────────────┘
```

::: {.callout-important}
vLLM must serve the **same base model** specified in your training config. If the models do not match, weight synchronization will silently produce incorrect results.
:::

## Server Mode {#sec-server-mode}

Server mode runs vLLM as an external process on dedicated GPU(s). This is the recommended configuration for most setups.
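
Server mode needs at least one GPU for vLLM in addition to the GPU(s) used for training. The sketch below is one way to check this before choosing a mode. It is not part of Axolotl: `suggest_vllm_mode` and `count_gpus` are hypothetical helper names, and the rule it applies (one GPU means colocate, two or more means server) is the rule of thumb given in this guide. It assumes `nvidia-smi` is on the `PATH` and counts physical GPUs, ignoring `CUDA_VISIBLE_DEVICES`.

```bash
#!/usr/bin/env bash

# Hypothetical helper (not part of Axolotl): map a GPU count to a vLLM mode
# using this guide's rule of thumb.
suggest_vllm_mode() {
  local gpu_count="$1"
  if [ "$gpu_count" -ge 2 ]; then
    echo "server"      # dedicated GPU(s) available for vLLM
  elif [ "$gpu_count" -eq 1 ]; then
    echo "colocate"    # vLLM must share the training GPU
  else
    echo "no GPU detected" >&2
    return 1
  fi
}

# Hypothetical helper: count physical GPUs reported by nvidia-smi.
# Prints 0 if nvidia-smi is missing or reports no devices.
count_gpus() {
  if command -v nvidia-smi >/dev/null 2>&1; then
    # Each GPU is listed as one line starting with its numeric index.
    nvidia-smi --query-gpu=index --format=csv,noheader 2>/dev/null \
      | grep -c '^[0-9]' || true
  else
    echo 0
  fi
}
```

Calling `suggest_vllm_mode "$(count_gpus)"` prints `server` or `colocate`, or returns a non-zero status when no GPU is found.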

### Starting the Server

Use the `axolotl vllm-serve` command with your training config:

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```

```bash
# Terminal 2: Start training on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```

The server reads vLLM settings from the `vllm:` section of your config and starts an HTTP server (default: `http://0.0.0.0:8000`).

::: {.callout-tip}
Use `tmux` or `screen` to manage the vLLM server process. Typical startup time is 30-90 seconds depending on model size and whether CUDA graphs are captured.
:::

### Minimal Server Config

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.85
  dtype: auto
  max_model_len: 4096

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

### Multi-GPU vLLM

For larger models, use tensor parallelism across multiple GPUs:

```yaml
vllm:
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
```

```bash
# vLLM on GPUs 2,3; training on GPUs 0,1
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo_config.yaml
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo_config.yaml --num-processes 2
```

::: {.callout-note}
Due to how TRL maps vLLM device indices, the vLLM instance should use the **last** N GPUs (highest device indices), while training uses the first N.
:::

## Colocate Mode {#sec-colocate-mode}

Colocate mode runs vLLM on the same GPU as the trainer. This is useful when you only have a single GPU.

```yaml
trl:
  use_vllm: true
  vllm_mode: colocate
  vllm_enable_sleep_mode: true
```

With `vllm_enable_sleep_mode: true`, vLLM offloads its VRAM allocation when not actively generating, freeing memory for training. When the trainer needs new completions, vLLM wakes up and reclaims VRAM.

::: {.callout-warning}
Colocate mode is significantly slower than server mode because generation and training cannot overlap. The GPU alternates between the two workloads.
This mode is practical only for smaller models (up to ~3B on a 24 GB GPU).
:::

**When to use colocate mode:**

- You have exactly one GPU
- The model fits in memory with both vLLM and training active (with sleep mode), or is small enough to time-share
- You accept the performance tradeoff for simpler setup (no separate vLLM process to manage)

**When to use server mode:**

- You have two or more GPUs
- You want maximum throughput (generation overlaps with training via async prefetch)
- You are running larger models (7B+)

## LoRA Sync {#sec-lora-sync}

LoRA sync is the recommended weight synchronization method when training with LoRA adapters. Instead of merging adapter weights into the base model and broadcasting the full merged weights over NCCL, it saves only the LoRA adapter files to the filesystem and tells vLLM to load them natively.

### How It Works

1. The trainer calls `model.save_pretrained()` to write the LoRA adapter weights to a temporary directory
2. The trainer sends an HTTP POST to `/set_lora_adapter/` on the vLLM server
3. vLLM loads the adapter using its native LoRA support (Punica kernels)
4. Generation uses the updated adapter on the next request

### Benefits

- **Smaller sync payload**: Transfers ~40 MB of LoRA weights instead of ~1.4 GB+ of merged model weights (for a typical 0.5-3B model)
- **No NCCL communicator**: Eliminates the need for a cross-GPU NCCL communication channel, removing GPU contention between vLLM generation and weight sync
- **Faster sync**: ~200 ms per sync vs. 350 ms to 5+ seconds for NCCL merge sync
- **Simpler multi-GPU**: No need to set up NCCL groups between trainer and vLLM processes

### Configuration

```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true    # Enables LoRA sync mode
  vllm_sync_interval: 5   # Sync every 5 training steps
```

Setting `vllm_lora_sync: true` automatically selects the LoRA-aware vLLM serve script (`axolotl.scripts.vllm_serve_lora`).
You do not need to set `vllm.serve_module` manually.

::: {.callout-important}
LoRA sync requires that you are training with a LoRA adapter (`adapter: lora` or `adapter: qlora`). It is not applicable to full fine-tuning.
:::

## Weight Synchronization {#sec-weight-sync}

During GRPO training, the policy model on the trainer is continuously updated via gradient steps. The vLLM server, however, still holds the old weights. Periodically, the trainer must push updated weights to vLLM so that future generations reflect the improved policy.

### Sync Interval

The `vllm_sync_interval` parameter controls how often weights are synced:

```yaml
trl:
  vllm_sync_interval: 5  # Sync every 5 optimizer steps
```

**Tradeoffs:**

- **Lower interval** (e.g., 1-3): Fresher generations, better on-policy data, but more sync overhead per step
- **Higher interval** (e.g., 5-10): Less overhead, but generations become increasingly off-policy between syncs
- **Recommended**: 3-5 for most setups. Axolotl includes importance sampling correction (`vllm_importance_sampling_correction: true`) to handle mild distribution mismatch from stale vLLM weights.

### Sync Methods

| Method | Config | Payload | Mechanism | Typical Time |
|--------|--------|---------|-----------|-------------|
| **LoRA sync** | `vllm_lora_sync: true` | LoRA adapter only (~40 MB) | Filesystem + HTTP | ~200 ms |
| **NCCL merge sync** | Default (no lora_sync) | Full merged weights (~1.4 GB+) | HTTP trigger + NCCL broadcast | 350 ms - 5 s |

::: {.callout-tip}
If you are training with LoRA (which is recommended for GRPO), always enable `vllm_lora_sync: true`. The performance difference is substantial, especially as training progresses and NCCL contention increases.
:::

### Importance Sampling Correction

When vLLM weights are stale (between syncs), the generated data is slightly off-policy.
Axolotl can correct for this:

```yaml
trl:
  vllm_importance_sampling_correction: true
  importance_sampling_level: token    # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5      # KL threshold for masking stale sequences
```

- **Token-level IS** is recommended when using Liger kernel (sequence-level has numerical issues with chunked computation)
- **Off-policy sequence masking (OPSM)** drops sequences that have diverged too far from the current policy, providing a safety net against stale data

## Restart Requirements {#sec-restart}

::: {.callout-warning}
**vLLM must be restarted between training runs.** Weight syncs from a previous run leave the server in a corrupted state. If you start a new training run against a stale vLLM server, the model may fail to learn.
:::

### When to Restart

- Before every new training experiment
- After a training run crashes or is interrupted
- If you change the base model in your config

### How to Restart

Killing vLLM reliably requires terminating both the main process and its background EngineCore subprocess:

```bash
# Kill all vLLM-related processes
pkill -9 -f "vllm|EngineCore"

# Verify GPU memory is freed
nvidia-smi

# Restart the server
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```

::: {.callout-tip}
A single `kill` often does not fully stop vLLM. Always use `kill -9` and verify with `nvidia-smi` that GPU memory has been released before restarting.
:::

### Health Check

The vLLM server exposes a health endpoint. Wait for it to return 200 before starting training:

```bash
# For the LoRA serve script (trailing slash required)
curl http://localhost:8000/health/

# For the default TRL serve script
curl http://localhost:8000/health
```

## Configuration Reference {#sec-config-reference}

### vLLM Server Options (`vllm:` section)

These control the vLLM server process started by `axolotl vllm-serve`.

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `host` | str | `0.0.0.0` | Host address for the vLLM server |
| `port` | int | `8000` | Port for the vLLM server |
| `device` | str | `auto` | Device to use for vLLM |
| `tensor_parallel_size` | int | `None` | Number of GPUs for tensor parallelism |
| `data_parallel_size` | int | `None` | Number of data parallel replicas |
| `gpu_memory_utilization` | float | `0.9` | Fraction of GPU memory for vLLM (0.0-1.0) |
| `dtype` | str | `auto` | Data type (`auto`, `float16`, `bfloat16`) |
| `max_model_len` | int | `None` | Maximum model context length. Set explicitly if the default is too large for your GPU |
| `enable_prefix_caching` | bool | `None` | Enable prefix caching for repeated prompt prefixes |
| `enable_reasoning` | bool | `None` | Enable reasoning mode for models with thinking tokens |
| `reasoning_parser` | str | `None` | Parser for reasoning output |
| `enforce_eager` | bool | `None` | Disable CUDA graph capture (required for some architectures like Qwen3.5 hybrid attention) |
| `serve_module` | str | `None` | Python module for vLLM serve script. Auto-set when `vllm_lora_sync: true` |
| `worker_extension_cls` | str | `None` | vLLM worker extension class for weight sync |

### Trainer vLLM Options (`trl:` section)

These control how the trainer interacts with vLLM.

| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | str | `None` | `server` (external process) or `colocate` (same GPU) |
| `vllm_server_host` | str | `0.0.0.0` | Host of the vLLM server to connect to |
| `vllm_server_port` | int | `8000` | Port of the vLLM server to connect to |
| `vllm_server_timeout` | int | `None` | Timeout in seconds for vLLM requests |
| `vllm_lora_sync` | bool | `false` | Sync LoRA adapters via filesystem instead of NCCL merge |
| `vllm_sync_interval` | int | `None` | Sync weights every N optimizer steps |
| `vllm_enable_sleep_mode` | bool | `None` | Offload vLLM VRAM when idle (colocate mode) |
| `vllm_guided_decoding_regex` | str | `None` | Regex constraint for guided decoding |

For async pipeline and off-policy correction options, see the [GRPO Configuration Reference](grpo.qmd#configuration-reference).

## Complete Example {#sec-complete-example}

For a full working GRPO config including vLLM, LoRA sync, async generation, rewards, and dataset setup, see the [GRPO Quick Start](grpo.qmd#quick-start). That config includes all the vLLM settings covered in this guide.

```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml

# Wait for health check to pass
curl http://localhost:8000/health/

# Terminal 2: Start training
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```

## Troubleshooting {#sec-troubleshooting}

| Problem | Likely Cause | Solution |
|---------|-------------|----------|
| Training hangs waiting for vLLM | Server not started or wrong port | Check `curl http://localhost:8000/health/` and verify `vllm_server_host`/`vllm_server_port` match |
| OOM on vLLM GPU | `gpu_memory_utilization` too high or `max_model_len` too large | Reduce `gpu_memory_utilization` to 0.7 or set `max_model_len` explicitly |
| OOM on training GPU | Batch too large for policy logprobs | Reduce `micro_batch_size` or `num_generations` |
| Accuracy stays at zero | Stale vLLM from previous run | Restart vLLM: `pkill -9 -f "vllm\|EngineCore"`, verify with `nvidia-smi`, restart |
| `ResponseValidationError` from vLLM | Missing logprobs in response | Ensure you are using the correct serve module (auto-selected with `vllm_lora_sync: true`) |
| Weight sync takes 5+ seconds | NCCL contention with vLLM generation | Switch to `vllm_lora_sync: true` to eliminate NCCL |
| `async_prefetch` deadlocks with FSDP | Background threads run unsynchronized FSDP collectives | Set `async_prefetch: false` when using FSDP or DeepSpeed multi-GPU |
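
The health check used throughout this guide can be scripted so that training starts only once the server responds. The following is a minimal sketch, not something shipped with Axolotl: `wait_for_vllm` is a hypothetical helper name, and the 120-second timeout and 2-second polling interval are assumed defaults chosen for illustration. The default URL assumes the LoRA serve script's `/health/` path on port 8000.

```bash
#!/usr/bin/env bash

# Hypothetical helper (not part of Axolotl): poll the vLLM health endpoint
# until it answers successfully or the timeout expires.
wait_for_vllm() {
  local url="${1:-http://localhost:8000/health/}"
  local timeout="${2:-120}"   # assumed default: give up after 120 seconds
  local interval="${3:-2}"    # assumed default: poll every 2 seconds
  local waited=0

  # --fail makes curl exit non-zero on HTTP error statuses (4xx/5xx),
  # so the loop continues until the endpoint returns a success code.
  until curl --silent --fail --output /dev/null "$url"; do
    if [ "$waited" -ge "$timeout" ]; then
      echo "vLLM not healthy after ${timeout}s: ${url}" >&2
      return 1
    fi
    sleep "$interval"
    waited=$((waited + interval))
  done
  echo "vLLM is healthy: ${url}"
}
```

For example, `wait_for_vllm http://localhost:8000/health/ 180 && CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml` launches training only if the server became healthy within three minutes. With the default TRL serve script, pass `http://localhost:8000/health` (no trailing slash) instead.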