fixes for spinning up vllm service for grpo (#3001)
This commit is contained in:
5
.github/workflows/main.yml
vendored
5
.github/workflows/main.yml
vendored
@@ -24,12 +24,13 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
|
is_latest: true
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
@@ -2,12 +2,10 @@
|
|||||||
CLI to start the vllm server for online RL
|
CLI to start the vllm server for online RL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import trl
|
|
||||||
from trl.scripts.vllm_serve import ScriptArguments
|
from trl.scripts.vllm_serve import ScriptArguments
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
@@ -42,13 +40,17 @@ def do_vllm_serve(
|
|||||||
|
|
||||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||||
|
tensor_parallel_size = 1
|
||||||
|
data_parallel_size = 1
|
||||||
|
|
||||||
tensor_parallel_size = (
|
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
|
||||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
tensor_parallel_size = (
|
||||||
)
|
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||||
data_parallel_size = (
|
)
|
||||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
|
||||||
)
|
data_parallel_size = (
|
||||||
|
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||||
|
)
|
||||||
host = cli_args.get("host") or cfg.vllm.host
|
host = cli_args.get("host") or cfg.vllm.host
|
||||||
port = cli_args.get("port") or cfg.vllm.port
|
port = cli_args.get("port") or cfg.vllm.port
|
||||||
gpu_memory_utilization = (
|
gpu_memory_utilization = (
|
||||||
@@ -81,63 +83,3 @@ def do_vllm_serve(
|
|||||||
enable_reasoning=enable_reasoning,
|
enable_reasoning=enable_reasoning,
|
||||||
)
|
)
|
||||||
vllm_serve_main(vllm_script_args)
|
vllm_serve_main(vllm_script_args)
|
||||||
|
|
||||||
|
|
||||||
def patch_vllm_worker():
|
|
||||||
from multiprocessing.connection import Connection
|
|
||||||
|
|
||||||
from vllm import LLM
|
|
||||||
|
|
||||||
def llm_worker(
|
|
||||||
script_args: AxolotlScriptArguments,
|
|
||||||
data_parallel_rank: int,
|
|
||||||
master_port: int,
|
|
||||||
connection: Connection,
|
|
||||||
) -> None:
|
|
||||||
# Set required environment variables for DP to work with vLLM
|
|
||||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
|
||||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
|
||||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
|
||||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
model=script_args.model,
|
|
||||||
revision=script_args.revision,
|
|
||||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
|
||||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
|
||||||
enforce_eager=script_args.enforce_eager,
|
|
||||||
dtype=script_args.dtype,
|
|
||||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
|
||||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
|
||||||
# This is particularly useful here because we generate completions from the same prompts.
|
|
||||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
|
||||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
|
||||||
max_model_len=script_args.max_model_len,
|
|
||||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
|
||||||
enable_reasoning=script_args.enable_reasoning,
|
|
||||||
reasoning_parser=script_args.reasoning_parser,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send ready signal to parent process
|
|
||||||
connection.send({"status": "ready"})
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Wait for commands from the parent process
|
|
||||||
try:
|
|
||||||
command = connection.recv()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
llm.collective_rpc(method="close_communicator")
|
|
||||||
break
|
|
||||||
|
|
||||||
# Handle commands
|
|
||||||
if command["type"] in ["call", "fire_and_forget"]:
|
|
||||||
method_name = command["method"]
|
|
||||||
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
|
||||||
method = getattr(llm, method_name)
|
|
||||||
result = method(*args, **kwargs)
|
|
||||||
if command["type"] == "call":
|
|
||||||
connection.send(result)
|
|
||||||
elif command["type"] == "shutdown":
|
|
||||||
break
|
|
||||||
|
|
||||||
trl.scripts.vllm_serve.llm_worker = llm_worker
|
|
||||||
|
|||||||
Reference in New Issue
Block a user