From 53391a10d73c67aa2c94adbbc1fcfebf6160f33b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 14 Apr 2026 15:52:02 +0000 Subject: [PATCH] vllm-serve-lora add /v1/completions route + worker pipe lock The LoRA vllm-serve wrapper only exposed /v1/chat/completions, but retrace's SWE agent server uses the token-id-aware /v1/completions endpoint so it can feed raw prompt_token_ids + track per-token logprobs across multi-turn rollouts. Add the route, mirroring the shape of /v1/chat/completions but routing to the vLLM worker's generate() method so prompt_token_ids are passed through as-is. Also add a worker_pipe_lock around conn.send/conn.recv. The multiprocessing.Connection to the vLLM worker is a single shared full-duplex pipe; concurrent HTTP requests interleave pickle frames on the wire and corrupt the stream (observed as UnpicklingError: pickle data was truncated, surfacing as 500s). The agent server fires ~8 concurrent rollout requests at once, so this was a hard blocker for any multi-concurrent workload. Serialize access to the pipe per-request round-trip. --- src/axolotl/scripts/vllm_serve_lora.py | 150 +++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index 344c4327f..4f0ec12fe 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -320,6 +320,15 @@ def main(script_args: ScriptArguments): # --- Active LoRA state (shared across endpoints via closure) --- active_lora: dict = {"request": None} + # Serializes access to the worker pipe. The underlying + # multiprocessing.Connection is a single full-duplex stream shared + # across all HTTP handlers; concurrent requests interleave bytes on + # the wire and corrupt the pickle framing (seen as + # ``UnpicklingError: pickle data was truncated``). Any endpoint that + # does ``conn.send(...); conn.recv()`` MUST hold this lock across + # the round-trip so only one inflight call at a time per pipe. + worker_pipe_lock = asyncio.Lock() + # ------------------------------------------------------------------ # LoRA-specific endpoints # ------------------------------------------------------------------ @@ -631,6 +640,147 @@ def main(script_args: ScriptArguments): }, } + @app.post("/v1/completions") + async def openai_completions(request_body: dict): + """OpenAI-compatible text-completions endpoint. + + Accepts either a string ``prompt`` or a list-of-int + ``prompt_token_ids`` (as the text-completions spec allows). Routes + to the internal vLLM generate method with the active LoRA adapter + and returns an OpenAI /v1/completions-shaped response including + per-choice ``prompt_token_ids``, ``generation_token_ids``, and + ``generation_log_probs`` for NeMo Gym agents that need raw + tokens + logprobs. + """ + import uuid + + prompt_raw = request_body.get("prompt") + temperature = request_body.get("temperature", 1.0) + max_tokens = request_body.get("max_tokens", 512) + top_p = request_body.get("top_p", 1.0) + n = request_body.get("n", 1) + logprobs = request_body.get("logprobs") or 0 + stop_token_ids = request_body.get("stop_token_ids") or None + + # Accept either a string or a list[int] token id prompt. Lists + # must contain ints only (raise on lists of strings so callers get + # a clear error). Also accept [[int, int, ...]] nesting for the + # rare case callers pass a single-prompt batch. 
+ if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list): + prompt_raw = prompt_raw[0] + + if isinstance(prompt_raw, list): + prompt_dict = {"prompt_token_ids": prompt_raw} + elif isinstance(prompt_raw, str): + prompt_dict = {"prompt": prompt_raw} + else: + return { + "error": { + "message": ( + "prompt must be a string or a list of token ids" + ), + "type": "invalid_request", + } + } + + generation_kwargs: dict[str, Any] = { + "n": n, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "logprobs": logprobs, + } + if stop_token_ids: + generation_kwargs["stop_token_ids"] = stop_token_ids + sampling_params = SamplingParams( + **{k: v for k, v in generation_kwargs.items() if v is not None} + ) + + chunked = chunk_list([prompt_dict], script_args.data_parallel_size) + + # Hold the pipe lock across send+recv — concurrent requests would + # otherwise interleave pickle frames on the worker connection. + async with worker_pipe_lock: + for conn, chunk in zip(connections, chunked, strict=True): + if not chunk: + chunk = [{"prompt": ""}] + kwargs = { + "prompts": chunk, + "sampling_params": sampling_params, + "lora_request": active_lora["request"], + } + conn.send({"type": "call", "method": "generate", "kwargs": kwargs}) + + loop = asyncio.get_running_loop() + all_outputs = await asyncio.gather( + *(loop.run_in_executor(None, safe_recv, conn) for conn in connections) + ) + + all_outputs = [ + o for o, c in zip(all_outputs, chunked, strict=True) if c + ] + for o in all_outputs: + if isinstance(o, dict) and "error" in o: + raise RuntimeError(f"vLLM worker error: {o['error']}") + all_outputs = list(chain.from_iterable(all_outputs)) + + if not all_outputs: + return {"choices": [], "model": script_args.model} + + choices = [] + for i, output in enumerate(all_outputs): + for j, out in enumerate(output.outputs): + text = out.text + # OpenAI-style `logprobs` block for text-completions: + # { "tokens": [...], "token_logprobs": [...] } + lp_block = None + if out.logprobs: + tokens_str: list[str] = [] + token_lps: list[float] = [] + for step in out.logprobs: + chosen = next(iter(step.values())) + tokens_str.append(getattr(chosen, "decoded_token", "") or "") + token_lps.append(float(chosen.logprob)) + lp_block = { + "tokens": tokens_str, + "token_logprobs": token_lps, + } + + choice = { + "index": i * n + j, + "text": text, + "finish_reason": "stop" if out.finish_reason == "stop" else "length", + "logprobs": lp_block, + # NeMo-Gym / retrace agent extras — preserved on the + # choice so callers with raw-token pipelines don't + # have to re-tokenize. + "prompt_token_ids": output.prompt_token_ids, + "generation_token_ids": list(out.token_ids), + "generation_log_probs": ( + [float(next(iter(lp.values())).logprob) for lp in out.logprobs] + if out.logprobs + else [] + ), + } + choices.append(choice) + + prompt_tokens = len(all_outputs[0].prompt_token_ids) if all_outputs else 0 + completion_tokens = sum( + len(out.token_ids) for o in all_outputs for out in o.outputs + ) + + return { + "id": f"cmpl-{uuid.uuid4().hex[:8]}", + "object": "text_completion", + "model": script_args.model, + "choices": choices, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + # --- Weight sync endpoints (legacy fallback, same as TRL) --- @app.post("/init_communicator/")
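For reference, a minimal client-side sketch of the new route. This is not part of the patch: the base URL, timeout, and token ids are placeholders, and only the request and response field names are taken from the handler above.

    import requests

    BASE_URL = "http://localhost:8000"  # placeholder vllm-serve-lora address

    resp = requests.post(
        f"{BASE_URL}/v1/completions",
        json={
            # "prompt" may be a string or a list of token ids; a list is
            # forwarded to the worker as prompt_token_ids without re-tokenizing.
            "prompt": [101, 2023, 2003, 1037, 3231, 102],  # placeholder ids
            "max_tokens": 64,
            "temperature": 1.0,
            "logprobs": 0,  # forwarded to SamplingParams by the handler
        },
        timeout=600,
    )
    resp.raise_for_status()
    choice = resp.json()["choices"][0]

    choice["text"]                  # decoded completion text
    choice["prompt_token_ids"]      # prompt ids echoed back per choice
    choice["generation_token_ids"]  # list[int] of sampled token ids
    choice["generation_log_probs"]  # list[float], one logprob per sampled id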
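The first hunk's locking rule, distilled into a hypothetical helper (call_worker does not exist in the patch; worker_pipe_lock, safe_recv, and the message shape do). The patch itself holds the lock across the whole fan-out to every worker connection, but the invariant is the same: the send and its matching recv must happen under a single lock acquisition, which is what lets the agent server fire several rollout requests at once without corrupting the pickle stream.

    async def call_worker(conn, method: str, **kwargs):
        # Hold the lock across the full round-trip so a second handler can
        # never interleave its pickle frames on the same Connection.
        async with worker_pipe_lock:
            conn.send({"type": "call", "method": method, "kwargs": kwargs})
            loop = asyncio.get_running_loop()
            # recv() blocks, so run it in the default executor to keep the
            # event loop responsive while the worker generates.
            return await loop.run_in_executor(None, safe_recv, conn)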