vllm-serve-lora: add /v1/completions route + worker pipe lock

The LoRA vllm-serve wrapper exposed only /v1/chat/completions, but
retrace's SWE agent server uses the token-id-aware /v1/completions
endpoint so it can feed raw prompt_token_ids and track per-token
logprobs across multi-turn rollouts. Add the route, mirroring the
shape of /v1/chat/completions but routing to the vLLM worker's
generate() method so prompt_token_ids are passed through as-is.
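
A caller that already has token ids can hit the route directly. A
minimal client-side sketch (the host/port and the example ids are
placeholders, not part of this change):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "prompt": [1, 2, 3],      # raw prompt_token_ids
            "max_tokens": 256,
            "temperature": 1.0,
            "logprobs": 1,
        },
    ).json()

    choice = resp["choices"][0]
    gen_ids = choice["generation_token_ids"]    # sampled token ids
    gen_lps = choice["generation_log_probs"]    # one logprob per sampled id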

Also add a worker_pipe_lock around conn.send/conn.recv. The
multiprocessing.Connection to the vLLM worker is a single shared
full-duplex pipe; concurrent HTTP requests interleave pickle frames
on the wire and corrupt the stream (observed as
"UnpicklingError: pickle data was truncated", surfacing as 500s).
The agent server fires ~8 rollout requests at once, so this was a
hard blocker for any concurrent workload. Serialize access to the
pipe so each request's send/recv round-trip owns it exclusively.
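
The fix is an asyncio.Lock held across each round-trip, with the
blocking recv pushed to a thread so the event loop stays responsive.
A minimal standalone sketch (the call_worker helper and conn argument
are illustrative names, not identifiers from the wrapper):

    import asyncio

    worker_pipe_lock = asyncio.Lock()

    async def call_worker(conn, request: dict):
        # One in-flight round-trip per pipe: send and recv happen under
        # the same lock, so two handlers can never interleave frames.
        async with worker_pipe_lock:
            conn.send(request)
            loop = asyncio.get_running_loop()
            # Connection.recv() blocks, so run it in the default thread
            # pool instead of blocking the event loop while we wait.
            return await loop.run_in_executor(None, conn.recv)
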
Author: Wing Lian
Date:   2026-04-14 15:52:02 +00:00
Commit: 53391a10d7
Parent: 7617b951a8

@@ -320,6 +320,15 @@ def main(script_args: ScriptArguments):
    # --- Active LoRA state (shared across endpoints via closure) ---
    active_lora: dict = {"request": None}

    # Serializes access to the worker pipe. The underlying
    # multiprocessing.Connection is a single full-duplex stream shared
    # across all HTTP handlers; concurrent requests interleave bytes on
    # the wire and corrupt the pickle framing (seen as
    # ``UnpicklingError: pickle data was truncated``). Any endpoint that
    # does ``conn.send(...); conn.recv()`` MUST hold this lock across
    # the round-trip so only one call is in flight at a time per pipe.
    worker_pipe_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # LoRA-specific endpoints
    # ------------------------------------------------------------------

@@ -631,6 +640,147 @@ def main(script_args: ScriptArguments):
                },
            }

@app.post("/v1/completions")
async def openai_completions(request_body: dict):
"""OpenAI-compatible text-completions endpoint.
Accepts either a string ``prompt`` or a list-of-int
``prompt_token_ids`` (as the text-completions spec allows). Routes
to the internal vLLM generate method with the active LoRA adapter
and returns an OpenAI /v1/completions-shaped response including
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
``generation_log_probs`` for NeMo Gym agents that need raw
tokens + logprobs.
"""
import uuid
prompt_raw = request_body.get("prompt")
temperature = request_body.get("temperature", 1.0)
max_tokens = request_body.get("max_tokens", 512)
top_p = request_body.get("top_p", 1.0)
n = request_body.get("n", 1)
logprobs = request_body.get("logprobs") or 0
stop_token_ids = request_body.get("stop_token_ids") or None
        # Accept either a string or a list[int] token-id prompt. Lists
        # must contain ints only (reject lists of strings so callers get
        # a clear error). Also accept [[int, int, ...]] nesting for the
        # rare case where callers pass a single-prompt batch.
        if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list):
            prompt_raw = prompt_raw[0]
        if isinstance(prompt_raw, list) and all(isinstance(t, int) for t in prompt_raw):
            prompt_dict = {"prompt_token_ids": prompt_raw}
        elif isinstance(prompt_raw, str):
            prompt_dict = {"prompt": prompt_raw}
        else:
            return {
                "error": {
                    "message": (
                        "prompt must be a string or a list of token ids"
                    ),
                    "type": "invalid_request",
                }
            }
        generation_kwargs: dict[str, Any] = {
            "n": n,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "logprobs": logprobs,
        }
        if stop_token_ids:
            generation_kwargs["stop_token_ids"] = stop_token_ids
        sampling_params = SamplingParams(
            **{k: v for k, v in generation_kwargs.items() if v is not None}
        )

        chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
        # Hold the pipe lock across send+recv; concurrent requests would
        # otherwise interleave pickle frames on the worker connection.
        async with worker_pipe_lock:
            for conn, chunk in zip(connections, chunked, strict=True):
                if not chunk:
                    chunk = [{"prompt": "<placeholder>"}]
                kwargs = {
                    "prompts": chunk,
                    "sampling_params": sampling_params,
                    "lora_request": active_lora["request"],
                }
                conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
            loop = asyncio.get_running_loop()
            all_outputs = await asyncio.gather(
                *(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
            )

        all_outputs = [
            o for o, c in zip(all_outputs, chunked, strict=True) if c
        ]
        for o in all_outputs:
            if isinstance(o, dict) and "error" in o:
                raise RuntimeError(f"vLLM worker error: {o['error']}")
        all_outputs = list(chain.from_iterable(all_outputs))
        if not all_outputs:
            return {"choices": [], "model": script_args.model}
        choices = []
        for i, output in enumerate(all_outputs):
            for j, out in enumerate(output.outputs):
                text = out.text
                # OpenAI-style `logprobs` block for text-completions:
                # { "tokens": [...], "token_logprobs": [...] }
                lp_block = None
                if out.logprobs:
                    tokens_str: list[str] = []
                    token_lps: list[float] = []
                    for step in out.logprobs:
                        chosen = next(iter(step.values()))
                        tokens_str.append(getattr(chosen, "decoded_token", "") or "")
                        token_lps.append(float(chosen.logprob))
                    lp_block = {
                        "tokens": tokens_str,
                        "token_logprobs": token_lps,
                    }
                choice = {
                    "index": i * n + j,
                    "text": text,
                    "finish_reason": "stop" if out.finish_reason == "stop" else "length",
                    "logprobs": lp_block,
                    # NeMo-Gym / retrace agent extras, preserved on the
                    # choice so callers with raw-token pipelines don't
                    # have to re-tokenize.
                    "prompt_token_ids": output.prompt_token_ids,
                    "generation_token_ids": list(out.token_ids),
                    "generation_log_probs": (
                        [float(next(iter(lp.values())).logprob) for lp in out.logprobs]
                        if out.logprobs
                        else []
                    ),
                }
                choices.append(choice)
        prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
        completion_tokens = sum(
            len(out.token_ids) for o in all_outputs for out in o.outputs
        )
        return {
            "id": f"cmpl-{uuid.uuid4().hex[:8]}",
            "object": "text_completion",
            "model": script_args.model,
            "choices": choices,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    # --- Weight sync endpoints (legacy fallback, same as TRL) ---
    @app.post("/init_communicator/")