vllm-serve-lora add /v1/completions route + worker pipe lock
The LoRA vllm-serve wrapper only exposed /v1/chat/completions, but retrace's SWE agent server uses the token-id-aware /v1/completions endpoint so it can feed raw prompt_token_ids + track per-token logprobs across multi-turn rollouts. Add the route, mirroring the shape of /v1/chat/completions but routing to the vLLM worker's generate() method so prompt_token_ids are passed through as-is. Also add a worker_pipe_lock around conn.send/conn.recv. The multiprocessing.Connection to the vLLM worker is a single shared full-duplex pipe; concurrent HTTP requests interleave pickle frames on the wire and corrupt the stream (observed as UnpicklingError: pickle data was truncated, surfacing as 500s). The agent server fires ~8 concurrent rollout requests at once, so this was a hard blocker for any multi-concurrent workload. Serialize access to the pipe per-request round-trip.
This commit is contained in:
@@ -320,6 +320,15 @@ def main(script_args: ScriptArguments):
|
||||
# --- Active LoRA state (shared across endpoints via closure) ---
|
||||
active_lora: dict = {"request": None}
|
||||
|
||||
# Serializes access to the worker pipe. The underlying
|
||||
# multiprocessing.Connection is a single full-duplex stream shared
|
||||
# across all HTTP handlers; concurrent requests interleave bytes on
|
||||
# the wire and corrupt the pickle framing (seen as
|
||||
# ``UnpicklingError: pickle data was truncated``). Any endpoint that
|
||||
# does ``conn.send(...); conn.recv()`` MUST hold this lock across
|
||||
# the round-trip so only one inflight call at a time per pipe.
|
||||
worker_pipe_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LoRA-specific endpoints
|
||||
# ------------------------------------------------------------------
|
||||
@@ -631,6 +640,147 @@ def main(script_args: ScriptArguments):
|
||||
},
|
||||
}
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def openai_completions(request_body: dict):
|
||||
"""OpenAI-compatible text-completions endpoint.
|
||||
|
||||
Accepts either a string ``prompt`` or a list-of-int
|
||||
``prompt_token_ids`` (as the text-completions spec allows). Routes
|
||||
to the internal vLLM generate method with the active LoRA adapter
|
||||
and returns an OpenAI /v1/completions-shaped response including
|
||||
per-choice ``prompt_token_ids``, ``generation_token_ids``, and
|
||||
``generation_log_probs`` for NeMo Gym agents that need raw
|
||||
tokens + logprobs.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
prompt_raw = request_body.get("prompt")
|
||||
temperature = request_body.get("temperature", 1.0)
|
||||
max_tokens = request_body.get("max_tokens", 512)
|
||||
top_p = request_body.get("top_p", 1.0)
|
||||
n = request_body.get("n", 1)
|
||||
logprobs = request_body.get("logprobs") or 0
|
||||
stop_token_ids = request_body.get("stop_token_ids") or None
|
||||
|
||||
# Accept either a string or a list[int] token id prompt. Lists
|
||||
# must contain ints only (raise on lists of strings so callers get
|
||||
# a clear error). Also accept [[int, int, ...]] nesting for the
|
||||
# rare case callers pass a single-prompt batch.
|
||||
if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list):
|
||||
prompt_raw = prompt_raw[0]
|
||||
|
||||
if isinstance(prompt_raw, list):
|
||||
prompt_dict = {"prompt_token_ids": prompt_raw}
|
||||
elif isinstance(prompt_raw, str):
|
||||
prompt_dict = {"prompt": prompt_raw}
|
||||
else:
|
||||
return {
|
||||
"error": {
|
||||
"message": (
|
||||
"prompt must be a string or a list of token ids"
|
||||
),
|
||||
"type": "invalid_request",
|
||||
}
|
||||
}
|
||||
|
||||
generation_kwargs: dict[str, Any] = {
|
||||
"n": n,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
if stop_token_ids:
|
||||
generation_kwargs["stop_token_ids"] = stop_token_ids
|
||||
sampling_params = SamplingParams(
|
||||
**{k: v for k, v in generation_kwargs.items() if v is not None}
|
||||
)
|
||||
|
||||
chunked = chunk_list([prompt_dict], script_args.data_parallel_size)
|
||||
|
||||
# Hold the pipe lock across send+recv — concurrent requests would
|
||||
# otherwise interleave pickle frames on the worker connection.
|
||||
async with worker_pipe_lock:
|
||||
for conn, chunk in zip(connections, chunked, strict=True):
|
||||
if not chunk:
|
||||
chunk = [{"prompt": "<placeholder>"}]
|
||||
kwargs = {
|
||||
"prompts": chunk,
|
||||
"sampling_params": sampling_params,
|
||||
"lora_request": active_lora["request"],
|
||||
}
|
||||
conn.send({"type": "call", "method": "generate", "kwargs": kwargs})
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
all_outputs = await asyncio.gather(
|
||||
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
||||
)
|
||||
|
||||
all_outputs = [
|
||||
o for o, c in zip(all_outputs, chunked, strict=True) if c
|
||||
]
|
||||
for o in all_outputs:
|
||||
if isinstance(o, dict) and "error" in o:
|
||||
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
||||
all_outputs = list(chain.from_iterable(all_outputs))
|
||||
|
||||
if not all_outputs:
|
||||
return {"choices": [], "model": script_args.model}
|
||||
|
||||
choices = []
|
||||
for i, output in enumerate(all_outputs):
|
||||
for j, out in enumerate(output.outputs):
|
||||
text = out.text
|
||||
# OpenAI-style `logprobs` block for text-completions:
|
||||
# { "tokens": [...], "token_logprobs": [...] }
|
||||
lp_block = None
|
||||
if out.logprobs:
|
||||
tokens_str: list[str] = []
|
||||
token_lps: list[float] = []
|
||||
for step in out.logprobs:
|
||||
chosen = next(iter(step.values()))
|
||||
tokens_str.append(getattr(chosen, "decoded_token", "") or "")
|
||||
token_lps.append(float(chosen.logprob))
|
||||
lp_block = {
|
||||
"tokens": tokens_str,
|
||||
"token_logprobs": token_lps,
|
||||
}
|
||||
|
||||
choice = {
|
||||
"index": i * n + j,
|
||||
"text": text,
|
||||
"finish_reason": "stop" if out.finish_reason == "stop" else "length",
|
||||
"logprobs": lp_block,
|
||||
# NeMo-Gym / retrace agent extras — preserved on the
|
||||
# choice so callers with raw-token pipelines don't
|
||||
# have to re-tokenize.
|
||||
"prompt_token_ids": output.prompt_token_ids,
|
||||
"generation_token_ids": list(out.token_ids),
|
||||
"generation_log_probs": (
|
||||
[float(next(iter(lp.values())).logprob) for lp in out.logprobs]
|
||||
if out.logprobs
|
||||
else []
|
||||
),
|
||||
}
|
||||
choices.append(choice)
|
||||
|
||||
prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0
|
||||
completion_tokens = sum(
|
||||
len(out.token_ids) for o in all_outputs for out in o.outputs
|
||||
)
|
||||
|
||||
return {
|
||||
"id": f"cmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "text_completion",
|
||||
"model": script_args.model,
|
||||
"choices": choices,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Weight sync endpoints (legacy fallback, same as TRL) ---
|
||||
|
||||
@app.post("/init_communicator/")
|
||||
|
||||
Reference in New Issue
Block a user