fixes for spinning up vllm service for grpo (#3001)

This commit is contained in:
Wing Lian
2025-08-02 11:19:24 -04:00
committed by GitHub
parent 5639552064
commit 10946afae7
2 changed files with 13 additions and 70 deletions

View File

@@ -24,12 +24,13 @@ jobs:
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.0 pytorch: 2.7.0
axolotl_extras: vllm axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.7.1 pytorch: 2.7.1
axolotl_extras: axolotl_extras: vllm
is_latest: true
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"

View File

@@ -2,12 +2,10 @@
CLI to start the vllm server for online RL CLI to start the vllm server for online RL
""" """
import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
@@ -42,13 +40,17 @@ def do_vllm_serve(
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main") vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = 1
data_parallel_size = 1
tensor_parallel_size = ( if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size tensor_parallel_size = (
) cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
data_parallel_size = ( )
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
) data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = ( gpu_memory_utilization = (
@@ -81,63 +83,3 @@ def do_vllm_serve(
enable_reasoning=enable_reasoning, enable_reasoning=enable_reasoning,
) )
vllm_serve_main(vllm_script_args) vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker