From 10946afae7877dcb8d8a6d574eb4eec3cf411e2a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 2 Aug 2025 11:19:24 -0400
Subject: [PATCH] fixes for spinning up vllm service for grpo (#3001)

---
 .github/workflows/main.yml    |  5 ++-
 src/axolotl/cli/vllm_serve.py | 78 +++++------------------------------
 2 files changed, 13 insertions(+), 70 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 444ebfde8..891300246 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -24,12 +24,13 @@ jobs:
             cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.7.0
-            axolotl_extras: vllm
+            axolotl_extras:
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.7.1
-            axolotl_extras:
+            axolotl_extras: vllm
+            is_latest: true
           - cuda: 128
             cuda_version: 12.8.1
             python_version: "3.11"
diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py
index f092cc59a..cf687bea2 100644
--- a/src/axolotl/cli/vllm_serve.py
+++ b/src/axolotl/cli/vllm_serve.py
@@ -2,12 +2,10 @@
 CLI to start the vllm server for online RL
 """
 
-import os
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Union
 
-import trl
 from trl.scripts.vllm_serve import ScriptArguments
 
 from axolotl.cli.config import load_cfg
@@ -42,13 +40,17 @@ def do_vllm_serve(
     serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
     vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
 
+    tensor_parallel_size = 1
+    data_parallel_size = 1
 
-    tensor_parallel_size = (
-        cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
-    )
-    data_parallel_size = (
-        cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
-    )
+    if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
+        tensor_parallel_size = (
+            cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
+        )
+    if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
+        data_parallel_size = (
+            cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
+        )
     host = cli_args.get("host") or cfg.vllm.host
     port = cli_args.get("port") or cfg.vllm.port
     gpu_memory_utilization = (
@@ -81,63 +83,3 @@ def do_vllm_serve(
         enable_reasoning=enable_reasoning,
     )
     vllm_serve_main(vllm_script_args)
-
-
-def patch_vllm_worker():
-    from multiprocessing.connection import Connection
-
-    from vllm import LLM
-
-    def llm_worker(
-        script_args: AxolotlScriptArguments,
-        data_parallel_rank: int,
-        master_port: int,
-        connection: Connection,
-    ) -> None:
-        # Set required environment variables for DP to work with vLLM
-        os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
-        os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
-        os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
-        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
-
-        llm = LLM(
-            model=script_args.model,
-            revision=script_args.revision,
-            tensor_parallel_size=script_args.tensor_parallel_size,
-            gpu_memory_utilization=script_args.gpu_memory_utilization,
-            enforce_eager=script_args.enforce_eager,
-            dtype=script_args.dtype,
-            # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
-            # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
-            # This is particularly useful here because we generate completions from the same prompts.
-            enable_prefix_caching=script_args.enable_prefix_caching,
-            kv_cache_dtype=script_args.kv_cache_dtype,
-            max_model_len=script_args.max_model_len,
-            worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
-            enable_reasoning=script_args.enable_reasoning,
-            reasoning_parser=script_args.reasoning_parser,
-        )
-
-        # Send ready signal to parent process
-        connection.send({"status": "ready"})
-
-        while True:
-            # Wait for commands from the parent process
-            try:
-                command = connection.recv()
-            except KeyboardInterrupt:
-                llm.collective_rpc(method="close_communicator")
-                break
-
-            # Handle commands
-            if command["type"] in ["call", "fire_and_forget"]:
-                method_name = command["method"]
-                args, kwargs = command.get("args", ()), command.get("kwargs", {})
-                method = getattr(llm, method_name)
-                result = method(*args, **kwargs)
-                if command["type"] == "call":
-                    connection.send(result)
-            elif command["type"] == "shutdown":
-                break
-
-    trl.scripts.vllm_serve.llm_worker = llm_worker
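Note: the behavioral change in do_vllm_serve is the defaulting. tensor_parallel_size and data_parallel_size now start at 1 and are only overwritten when cli_args or cfg.vllm actually supply a value, so TRL's ScriptArguments no longer receives None when neither source sets them. A minimal sketch of that fallback pattern, using hypothetical stand-in values rather than the real cli_args/cfg objects:

    # Sketch of the patched fallback logic; resolve_size and its arguments are
    # hypothetical stand-ins for cli_args.get(...) and cfg.vllm.* lookups.
    def resolve_size(cli_value, cfg_value, default=1):
        # `or` treats None/0 as unset; the guard keeps the default
        # instead of evaluating the chain when both sources are unset
        if cli_value or cfg_value:
            return cli_value or cfg_value
        return default

    assert resolve_size(None, None) == 1  # neither set: safe default, not None
    assert resolve_size(None, 4) == 4     # config value used when CLI omits it
    assert resolve_size(2, 4) == 2        # explicit CLI value wins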