* support flattening/packing for GRPO * more flattening * fix tests * improve dead vllm handling * refactor out process handling for vllm serve and move bench flattening tests to gpu tests * add validation for flattening with liger * isolate batch flattening test * flaky test
758 lines
28 KiB
Python
758 lines
28 KiB
Python
"""vLLM serve script with native LoRA adapter support.
|
|
|
|
Extends TRL's vllm_serve to enable direct LoRA adapter loading in vLLM,
|
|
instead of merging adapter weights into the base model before syncing.
|
|
|
|
Usage:
|
|
Set ``vllm.serve_module: axolotl.scripts.vllm_serve_lora`` in your config,
|
|
or ``trl.vllm_lora_sync: true`` to auto-select.
|
|
|
|
Benefits over merge-sync:
|
|
- Syncs only LoRA adapter weights via filesystem instead of full merged model via NCCL
|
|
- vLLM handles LoRA application natively (Punica kernels)
|
|
- No NCCL communicator needed for weight sync
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from itertools import chain
|
|
from multiprocessing import Pipe, Process
|
|
from multiprocessing.connection import Connection
|
|
from typing import Any
|
|
|
|
from trl.scripts.vllm_serve import (
|
|
ScriptArguments,
|
|
chunk_list,
|
|
extract_logprobs,
|
|
)
|
|
|
|
try:
|
|
from trl.scripts.vllm_serve import get_open_port
|
|
except ImportError:
|
|
try:
|
|
from vllm.utils import get_open_port
|
|
except ImportError:
|
|
from vllm.utils.network_utils import get_open_port
|
|
from vllm import LLM, SamplingParams
|
|
from vllm.lora.request import LoRARequest
|
|
|
|
from axolotl.scripts.process_cleanup import (
|
|
ProcessManager,
|
|
is_fatal_worker_error,
|
|
safe_recv,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class LoRAScriptArguments(ScriptArguments):
    """TRL's ``ScriptArguments`` extended with vLLM LoRA options.

    Every added field has a default, so this class can be built from a plain
    ``ScriptArguments`` by copying the shared fields and filling in the rest.
    """

    # Forwarded to vLLM's ``LLM(enable_lora=...)``.
    enable_lora: bool = field(
        metadata={"help": "Enable LoRA adapter support in vLLM."},
        default=True,
    )
    # Upper bound on adapter rank vLLM allocates for (``LLM(max_lora_rank=...)``).
    max_lora_rank: int = field(
        metadata={"help": "Maximum LoRA rank supported."},
        default=64,
    )
    # How many adapters vLLM keeps resident at once (``LLM(max_loras=...)``).
    max_loras: int = field(
        metadata={"help": "Maximum number of LoRA adapters loaded simultaneously."},
        default=2,
    )
    # dtype string handed to ``LLM(lora_dtype=...)``.
    lora_dtype: str = field(
        metadata={"help": "Data type for LoRA weights."},
        default="bfloat16",
    )
    # Dotted path of the worker extension used for the weight-sync RPCs.
    worker_extension_cls: str = field(
        metadata={"help": "vLLM worker extension class for weight synchronization."},
        default="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
    )
|
|
|
|
|
|
def _rebuild_lora_request(payload: "dict[str, Any]") -> "LoRARequest":
    """Build a vLLM ``LoRARequest`` from the plain dict sent by the HTTP server.

    The server stores the active adapter as a dict with ``lora_name``,
    ``lora_int_id``, ``lora_path`` and optionally ``load_inplace`` keys.
    """
    return LoRARequest(
        lora_name=payload["lora_name"],
        lora_int_id=payload["lora_int_id"],
        lora_path=payload["lora_path"],
        load_inplace=payload.get("load_inplace", False),
    )


def _serve_commands(llm: Any, connection: Connection) -> None:
    """Execute commands received on ``connection`` against ``llm`` until told to stop.

    Each command is a dict with a ``"type"`` key:

    - ``"shutdown"``: return.
    - ``"call"``: run ``llm.<method>(*args, **kwargs)`` and send the result back,
      or ``{"error": str(exc), "kind": "worker_error"}`` if it raised.
    - ``"fire_and_forget"``: same invocation, but nothing is sent back.

    Other command types are ignored. The loop also returns when the pipe is
    closed (``EOFError``), on ``KeyboardInterrupt``, or after an exception that
    ``is_fatal_worker_error`` classifies as fatal.
    """
    while True:
        try:
            command = connection.recv()
        except (KeyboardInterrupt, EOFError):
            return

        command_type = command.get("type")
        if command_type == "shutdown":
            return
        if command_type not in ("call", "fire_and_forget"):
            continue

        method_name = command.get("method")
        args = command.get("args", ())
        kwargs = command.get("kwargs", {})

        try:
            # The adapter arrives as a plain dict; rebuild the vLLM object here.
            # Done inside the try so a bad payload produces an error reply
            # instead of killing the worker and leaving the caller blocked.
            if kwargs.get("lora_request") is not None:
                kwargs["lora_request"] = _rebuild_lora_request(kwargs["lora_request"])
            result = getattr(llm, method_name)(*args, **kwargs)
        except Exception as exc:
            logger.warning("Worker method %s failed: %s", method_name, exc)
            if command_type == "call":
                connection.send({"error": str(exc), "kind": "worker_error"})
            if is_fatal_worker_error(exc):
                logger.error(
                    "Fatal worker error (EngineCore died), exiting. "
                    "Restart the vLLM server to recover."
                )
                return
            continue

        if command_type == "call":
            connection.send(result)


def llm_worker(
    script_args: LoRAScriptArguments,
    data_parallel_rank: int,
    master_port: int,
    connection: Connection,
) -> None:
    """Worker process that creates a vLLM LLM with LoRA enabled.

    Builds the engine, sends ``{"status": "ready"}`` on ``connection``, then
    serves commands (see ``_serve_commands``) until shutdown. On exit it always
    tries to close the weight-sync communicator and kill EngineCore processes.

    Args:
        script_args: Engine and LoRA configuration.
        data_parallel_rank: Index of this worker among the DP workers.
        master_port: Port exported as ``VLLM_DP_MASTER_PORT`` when not pinning GPUs.
        connection: Child end of the pipe shared with the HTTP server.
    """
    # For DP with TP=1: pin each worker to its own GPU via CUDA_VISIBLE_DEVICES.
    # vLLM's LLM() offline mode doesn't support DP env vars natively, so we
    # isolate each worker to a single GPU and let vLLM think it's the only one.
    if script_args.data_parallel_size > 1 and script_args.tensor_parallel_size == 1:
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if visible:
            gpu_ids = visible.split(",")
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[data_parallel_rank]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(data_parallel_rank)
    else:
        os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
        os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
        os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)

    llm = LLM(
        model=script_args.model,
        revision=script_args.revision,
        tensor_parallel_size=script_args.tensor_parallel_size,
        gpu_memory_utilization=script_args.gpu_memory_utilization,
        enforce_eager=script_args.enforce_eager,
        dtype=script_args.dtype,
        enable_prefix_caching=script_args.enable_prefix_caching,
        kv_cache_dtype=script_args.kv_cache_dtype,
        max_model_len=script_args.max_model_len,
        worker_extension_cls=script_args.worker_extension_cls,
        trust_remote_code=script_args.trust_remote_code,
        model_impl=script_args.vllm_model_impl,
        logprobs_mode="processed_logprobs",
        # LoRA
        enable_lora=script_args.enable_lora,
        max_lora_rank=script_args.max_lora_rank,
        max_loras=script_args.max_loras,
        lora_dtype=script_args.lora_dtype,
    )

    connection.send({"status": "ready"})

    def _worker_cleanup() -> None:
        """Clean up the LLM and its EngineCore subprocess on worker exit."""
        from axolotl.scripts.process_cleanup import cleanup_orphan_processes

        try:
            llm.collective_rpc(method="close_communicator")
        except Exception as exc:
            # Best effort: may fail e.g. if no communicator was ever initialized.
            logger.debug("close_communicator during worker cleanup failed: %s", exc)
        try:
            # Kill EngineCore children of this worker
            cleanup_orphan_processes("VLLM::EngineCore")
        except Exception as exc:
            # Don't let cleanup failures mask the reason the worker is exiting.
            logger.warning("EngineCore cleanup failed: %s", exc)

    # Cleanup runs in ``finally`` rather than via ``atexit``: multiprocessing
    # children started with the "fork"/"forkserver" methods end with
    # ``os._exit()``, which skips atexit handlers entirely.
    try:
        _serve_commands(llm, connection)
    finally:
        _worker_cleanup()
|
|
|
|
|
|
def main(script_args: ScriptArguments):
    """Start vLLM workers with LoRA support and the HTTP server.

    Spawns one ``llm_worker`` process per data-parallel rank and waits (in the
    FastAPI lifespan) for each to report ready. It then serves a FastAPI app
    with uvicorn; the app forwards generate/chat/weight-sync requests to the
    workers over ``multiprocessing`` pipes.

    Args:
        script_args: Server configuration. A plain ``ScriptArguments`` is upgraded
            to ``LoRAScriptArguments`` with the LoRA field defaults filled in.
    """
    import asyncio
    import base64
    import time
    import uuid
    from io import BytesIO

    import uvicorn
    import vllm
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    from packaging.version import Version
    from pydantic import BaseModel, Field as PydanticField

    # Request/Response models (defined locally like TRL's vllm_serve.main)
    class GenerateRequest(BaseModel):
        prompts: list[str] | list[list[int]]
        images: list[str] | None = None
        n: int = 1
        repetition_penalty: float = 1.0
        temperature: float = 1.0
        top_p: float = 1.0
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        logprobs: int | None = 0
        truncate_prompt_tokens: int | None = None
        structured_outputs_regex: str | None = None
        generation_kwargs: dict = PydanticField(default_factory=dict)

    class GenerateResponse(BaseModel):
        prompt_ids: list[list[int]]
        completion_ids: list[list[int]]
        logprobs: list[list[list[float]]]
        logprob_token_ids: list[list[list[int]]]

    class ChatRequest(BaseModel):
        messages: list[list[dict]]
        n: int = 1
        repetition_penalty: float = 1.0
        temperature: float = 1.0
        top_p: float = 1.0
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        logprobs: int | None = 0
        truncate_prompt_tokens: int | None = None
        structured_outputs_regex: str | None = None
        generation_kwargs: dict = PydanticField(default_factory=dict)
        chat_template_kwargs: dict = PydanticField(default_factory=dict)

    class ChatResponse(BaseModel):
        prompt_ids: list[list[int]]
        completion_ids: list[list[int]]
        logprobs: list[list[list[float]]]
        logprob_token_ids: list[list[list[int]]]

    class InitCommunicatorRequest(BaseModel):
        host: str
        port: int
        world_size: int
        client_device_uuid: str

    # Wrap plain ScriptArguments with LoRA defaults
    if not isinstance(script_args, LoRAScriptArguments):
        lora_args = LoRAScriptArguments.__new__(LoRAScriptArguments)
        for f in ScriptArguments.__dataclass_fields__:
            setattr(lora_args, f, getattr(script_args, f))
        # Apply LoRA defaults
        for f in LoRAScriptArguments.__dataclass_fields__:
            if f not in ScriptArguments.__dataclass_fields__:
                setattr(
                    lora_args, f, LoRAScriptArguments.__dataclass_fields__[f].default
                )
        script_args = lora_args

    # Spawn one worker process per data-parallel rank.
    master_port = get_open_port()
    connections: list[Connection] = []
    processes: list[Process] = []
    for dp_rank in range(script_args.data_parallel_size):
        parent_conn, child_conn = Pipe()
        process = Process(
            target=llm_worker,
            args=(script_args, dp_rank, master_port, child_conn),
        )
        process.start()
        connections.append(parent_conn)
        processes.append(process)

    # Process lifecycle management
    manager = ProcessManager(processes, connections)
    manager.register_cleanup()

    # vLLM <= 0.10.2 takes constrained decoding as ``guided_decoding``; newer
    # versions use ``structured_outputs``. The installed version cannot change
    # while the server runs, so decide once.
    use_guided_decoding = Version(vllm.__version__) <= Version("0.10.2")

    # --- Active LoRA state (shared across endpoints via closure) ---
    active_lora: dict = {"request": None}

    # Every worker pipe carries strict request→reply pairs, and the Python docs
    # warn that a pipe end may be corrupted if two threads use it at once.
    # Without this lock, concurrent requests could interleave their sends and
    # their executor-thread recv() calls could grab each other's replies.
    # Endpoints that don't touch the pipes (e.g. /set_lora_adapter/) stay
    # concurrent.
    pipe_lock = asyncio.Lock()

    async def _broadcast(msg: dict) -> None:
        """Send ``msg`` to every worker without waiting for any reply."""
        loop = asyncio.get_running_loop()
        async with pipe_lock:
            await asyncio.gather(
                *(loop.run_in_executor(None, c.send, msg) for c in connections)
            )

    async def _call_workers(
        method: str,
        arg_name: str,
        items: list,
        placeholder: list,
        **call_kwargs: Any,
    ) -> list:
        """Split ``items`` across workers, call ``llm.<method>`` on each, gather results.

        Every worker must receive a command so replies stay aligned; workers
        whose chunk is empty get ``placeholder`` instead, and their results are
        discarded. The active LoRA adapter is read under the pipe lock, so a
        request uses the adapter current at dispatch time.

        Returns:
            The flattened per-prompt outputs from workers with non-empty chunks.

        Raises:
            RuntimeError: if any of those workers replied with an error.
        """
        chunks = chunk_list(items, script_args.data_parallel_size)
        loop = asyncio.get_running_loop()
        async with pipe_lock:
            for conn, chunk in zip(connections, chunks, strict=True):
                kwargs = {
                    arg_name: chunk or placeholder,
                    **call_kwargs,
                    "lora_request": active_lora["request"],
                }
                conn.send({"type": "call", "method": method, "kwargs": kwargs})
            # Use run_in_executor so blocking recv() doesn't freeze the event loop
            # (allows /set_lora_adapter/ and other endpoints to be served concurrently)
            replies = await asyncio.gather(
                *(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
            )
        replies = [r for r, chunk in zip(replies, chunks, strict=True) if chunk]
        # Check for worker errors before flattening
        for reply in replies:
            if isinstance(reply, dict) and "error" in reply:
                raise RuntimeError(f"vLLM worker error: {reply['error']}")
        return list(chain.from_iterable(replies))

    def _sampling_kwargs(request: GenerateRequest | ChatRequest) -> dict[str, Any]:
        """Collect SamplingParams kwargs from a request; ``generation_kwargs`` wins."""
        kwargs: dict[str, Any] = {
            "n": request.n,
            "repetition_penalty": request.repetition_penalty,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "min_p": request.min_p,
            "max_tokens": request.max_tokens,
            "logprobs": request.logprobs,
        }
        kwargs.update(request.generation_kwargs)
        return kwargs

    def _apply_structured_outputs(
        generation_kwargs: dict[str, Any], regex: str | None
    ) -> None:
        """Set the constrained-decoding entry of ``generation_kwargs`` in place.

        A non-None ``regex`` takes precedence. With newer vLLM, a dict given
        under ``structured_outputs`` is converted to ``StructuredOutputsParams``.
        """
        if use_guided_decoding:
            key = "guided_decoding"
            if regex is not None:
                from vllm.sampling_params import GuidedDecodingParams

                generation_kwargs[key] = GuidedDecodingParams(regex=regex)
            else:
                generation_kwargs.setdefault(key, None)
            return

        from vllm.sampling_params import StructuredOutputsParams

        key = "structured_outputs"
        if regex is not None:
            generation_kwargs[key] = StructuredOutputsParams(regex=regex)
        elif isinstance(generation_kwargs.get(key), dict):
            generation_kwargs[key] = StructuredOutputsParams(**generation_kwargs[key])
        else:
            generation_kwargs.setdefault(key, None)

    def _token_response(outputs: list) -> dict[str, Any]:
        """Format vLLM outputs as the prompt/completion/logprob payload."""
        logprob_info = extract_logprobs(outputs)  # computed once, used twice
        return {
            "prompt_ids": [o.prompt_token_ids for o in outputs],
            "completion_ids": [
                list(out.token_ids) for o in outputs for out in o.outputs
            ],
            "logprobs": logprob_info[0],
            "logprob_token_ids": logprob_info[1],
        }

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Wait for all workers to report ready, then monitor them until shutdown.

        Raises:
            RuntimeError: if a worker dies during startup or not all workers are
                ready within the startup timeout.
        """
        startup_timeout = 300  # 5 minutes
        start_time = time.monotonic()
        ready: set[int] = set()
        while len(ready) < script_args.data_parallel_size:
            elapsed = time.monotonic() - start_time
            if elapsed > startup_timeout:
                raise RuntimeError(
                    f"vLLM workers failed to start within {startup_timeout}s "
                    f"({len(ready)}/{script_args.data_parallel_size} ready)"
                )
            for i, (conn, proc) in enumerate(zip(connections, processes, strict=True)):
                if i in ready:
                    continue
                if not proc.is_alive():
                    raise RuntimeError(
                        f"vLLM worker {i} exited unexpectedly during startup"
                    )
                if conn.poll():
                    msg = conn.recv()
                    if isinstance(msg, dict) and msg.get("status") == "ready":
                        ready.add(i)
            await asyncio.sleep(0.1)

        monitor_task = asyncio.create_task(manager.monitor_workers())
        try:
            yield
        finally:
            # Tear down even if the app exits with an error.
            monitor_task.cancel()
            manager._shutdown_workers()

    app = FastAPI(lifespan=lifespan)

    # --- Access logging middleware ---
    @app.middleware("http")
    async def access_log_middleware(request, call_next):
        t0 = time.monotonic()
        response = await call_next(request)
        elapsed = time.monotonic() - t0
        logger.info(
            "%s %s %d %.3fs",
            request.method,
            request.url.path,
            response.status_code,
            elapsed,
        )
        return response

    # ------------------------------------------------------------------
    # LoRA-specific endpoints
    # ------------------------------------------------------------------

    class SetLoRARequest(BaseModel):
        lora_name: str
        lora_int_id: int
        lora_path: str
        load_inplace: bool = False

    @app.post("/set_lora_adapter/")
    async def set_lora_adapter(request: SetLoRARequest):
        """Register a LoRA adapter for all subsequent generate/chat calls."""
        active_lora["request"] = {
            "lora_name": request.lora_name,
            "lora_int_id": request.lora_int_id,
            "lora_path": request.lora_path,
            "load_inplace": request.load_inplace,
        }
        logger.info(
            "Set active LoRA: %s (id=%d, path=%s)",
            request.lora_name,
            request.lora_int_id,
            request.lora_path,
        )
        return {"status": "ok"}

    @app.post("/clear_lora_adapter/")
    async def clear_lora_adapter():
        """Clear active LoRA adapter (revert to base model)."""
        active_lora["request"] = None
        return {"status": "ok"}

    # ------------------------------------------------------------------
    # Standard endpoints (mirrors TRL's vllm_serve)
    # ------------------------------------------------------------------

    @app.get("/health/")
    async def health():
        """Return the manager's health status, with HTTP 503 when it is not "ok"."""
        status = manager.get_health_status()
        if status["status"] != "ok":
            return JSONResponse(status_code=503, content=status)
        return status

    @app.get("/get_world_size/")
    async def get_world_size():
        return {
            "world_size": script_args.tensor_parallel_size
            * script_args.data_parallel_size
        }

    @app.post("/generate/", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        """Generate completions with optional LoRA adapter.

        ``prompts`` may be strings or token-ID lists. ``images``, if given, holds
        one base64-encoded image (or null) per prompt.
        """
        manager.check_workers_alive()

        images: list[str | None] = request.images or [None] * len(request.prompts)  # type: ignore[assignment,list-item]
        prompts: list[dict[str, Any]] = []
        for prompt, image in zip(request.prompts, images, strict=True):
            # Support both string prompts and token ID lists
            row: dict[str, Any] = (
                {"prompt_token_ids": prompt}
                if isinstance(prompt, list)
                else {"prompt": prompt}
            )
            if image is not None:
                from PIL import Image

                row["multi_modal_data"] = {
                    "image": Image.open(BytesIO(base64.b64decode(image)))
                }
            prompts.append(row)

        generation_kwargs = _sampling_kwargs(request)
        _apply_structured_outputs(generation_kwargs, request.structured_outputs_regex)
        sampling_params = SamplingParams(**generation_kwargs)

        outputs = await _call_workers(
            "generate",
            "prompts",
            prompts,
            [{"prompt": "<placeholder>"}],
            sampling_params=sampling_params,
        )
        return _token_response(outputs)

    @app.post("/chat/", response_model=ChatResponse)
    async def chat(request: ChatRequest):
        """Chat endpoint with optional LoRA adapter."""
        manager.check_workers_alive()
        sampling_params = SamplingParams(**_sampling_kwargs(request))
        outputs = await _call_workers(
            "chat",
            "messages",
            request.messages,
            [[{"role": "user", "content": "<placeholder>"}]],
            sampling_params=sampling_params,
            use_tqdm=False,
            # Accepted by ChatRequest, so it must reach LLM.chat.
            chat_template_kwargs=request.chat_template_kwargs,
        )
        return _token_response(outputs)

    # --- OpenAI-compatible endpoints (for NeMo Gym agent integration) ---

    @app.get("/v1/models")
    async def list_models():
        """OpenAI-compatible models endpoint."""
        return {
            "object": "list",
            "data": [
                {"id": script_args.model, "object": "model", "owned_by": "axolotl"}
            ],
        }

    @app.post("/v1/chat/completions")
    async def openai_chat_completions(request_body: dict):
        """OpenAI-compatible chat completions endpoint.

        Translates OpenAI format to our internal /chat/ format so NeMo Gym's
        model server proxy can call us directly.

        Only these request fields are used: ``messages``, ``temperature``,
        ``max_tokens``, ``top_p`` and ``n``. Each request is one conversation.
        Besides the standard fields, each choice carries ``prompt_token_ids``,
        ``generation_token_ids`` and, when logprobs are present,
        ``generation_log_probs``.
        """
        manager.check_workers_alive()

        generation_kwargs = {
            "n": request_body.get("n", 1),
            "temperature": request_body.get("temperature", 1.0),
            "top_p": request_body.get("top_p", 1.0),
            "max_tokens": request_body.get("max_tokens", 512),
            "logprobs": 0,  # Always return logprobs (NeMo Gym needs them)
        }
        # Explicit JSON nulls are dropped so SamplingParams uses its own defaults.
        sampling_params = SamplingParams(
            **{k: v for k, v in generation_kwargs.items() if v is not None}
        )

        outputs = await _call_workers(
            "chat",
            "messages",
            [request_body.get("messages", [])],
            [[{"role": "user", "content": "<placeholder>"}]],
            sampling_params=sampling_params,
            use_tqdm=False,
        )
        if not outputs:
            return {"choices": [], "model": script_args.model}

        choices: list[dict[str, Any]] = []
        for output in outputs:
            for out in output.outputs:
                # First entry of each per-position logprob dict, if logprobs exist.
                token_logprobs = (
                    [next(iter(lp.values())).logprob for lp in out.logprobs]
                    if out.logprobs
                    else None
                )
                choice: dict[str, Any] = {
                    # Sequential across all returned completions; unlike
                    # ``i * n + j`` this does not break on ``"n": null``.
                    "index": len(choices),
                    "message": {"role": "assistant", "content": out.text},
                    "finish_reason": "stop"
                    if out.finish_reason == "stop"
                    else "length",
                    "logprobs": (
                        {
                            "content": [
                                {"token": "", "logprob": value}  # nosec B105
                                for value in token_logprobs
                            ]
                        }
                        if token_logprobs is not None
                        else None
                    ),
                    # Include token ID information for NeMo Gym
                    "prompt_token_ids": output.prompt_token_ids,
                    "generation_token_ids": list(out.token_ids),
                }
                if token_logprobs is not None:
                    choice["generation_log_probs"] = token_logprobs
                choices.append(choice)

        prompt_tokens = len(outputs[0].prompt_token_ids)
        completion_tokens = sum(
            len(out.token_ids) for o in outputs for out in o.outputs
        )

        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "object": "chat.completion",
            "model": script_args.model,
            "choices": choices,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    # --- Weight sync endpoints (legacy fallback, same as TRL) ---
    # All of these are fire-and-forget: workers send no reply, so nothing is
    # left in the pipe to be mistaken for the answer to a later generate/chat.

    @app.post("/init_communicator/")
    async def init_communicator(request: InitCommunicatorRequest):
        # All vLLM ranks, +1 (presumably the client process).
        world_size = (
            script_args.tensor_parallel_size * script_args.data_parallel_size + 1
        )
        kwargs = {
            "method": "init_communicator",
            "args": (
                request.host,
                request.port,
                world_size,
                request.client_device_uuid,
            ),
        }
        await _broadcast(
            {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        )
        return {"message": "Initializing communicator"}

    class UpdateWeightsRequest(BaseModel):
        name: str
        dtype: str
        shape: list[int]

    @app.post("/update_named_param/")
    async def update_named_param(request: UpdateWeightsRequest):
        kwargs = {
            "method": "update_named_param",
            "args": (request.name, request.dtype, tuple(request.shape)),
        }
        await _broadcast(
            {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        )
        return {"message": "Updating parameter"}

    class BatchUpdateWeightsRequest(BaseModel):
        params: list[dict]

    @app.post("/batch_update_named_params/")
    async def batch_update_named_params(request: BatchUpdateWeightsRequest):
        params_list = [
            (p["name"], p["dtype"], tuple(p["shape"])) for p in request.params
        ]
        kwargs = {"method": "batch_update_named_params", "args": (params_list,)}
        await _broadcast(
            {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        )
        return {"message": f"Batch update for {len(params_list)} params"}

    class HTTPWeightUpdateRequest(BaseModel):
        """Weight update via HTTP (no NCCL needed)."""

        params: list[
            dict
        ]  # [{"name": str, "dtype": str, "shape": list, "data": str (base64)}]

    @app.post("/http_update_weights/")
    async def http_update_weights(request: HTTPWeightUpdateRequest):
        """Update model weights via HTTP — no NCCL communicator required.

        Tensor data is sent as base64-encoded raw bytes in the request body.
        Slower than NCCL for large models but works without cross-process setup.
        """
        from axolotl.utils.weight_serde import (
            decode_from_http,
            encode_for_ipc,
        )

        weights_to_load = [decode_from_http(p) for p in request.params]

        # Send all weights in a single IPC call. Tensors don't survive
        # vLLM's multiproc IPC, so serialize as raw bytes + metadata.
        param_entries = [
            encode_for_ipc(name, weight) for name, weight in weights_to_load
        ]
        kwargs = {
            "method": "http_load_weights_batch",
            "kwargs": {"params": param_entries},
        }
        await _broadcast(
            {"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}
        )
        return {"message": f"HTTP weight update for {len(weights_to_load)} params"}

    @app.post("/reset_prefix_cache/")
    async def reset_prefix_cache():
        # Fire-and-forget: a reply would sit in the pipe and corrupt the next
        # recv() for generate/chat calls.
        await _broadcast({"type": "fire_and_forget", "method": "reset_prefix_cache"})
        return {"message": "Reset prefix cache received"}

    @app.post("/close_communicator/")
    async def close_communicator():
        await _broadcast(
            {
                "type": "fire_and_forget",
                "method": "collective_rpc",
                "kwargs": {"method": "close_communicator"},
            }
        )
        return {"message": "Closing communicator"}

    uvicorn.run(
        app,
        host=script_args.host,
        port=script_args.port,
        log_level=script_args.log_level,
        access_log=True,
    )