This commit is contained in:
Wing Lian
2026-04-14 17:26:00 -04:00
committed by Wing Lian
parent 53391a10d7
commit d4e9cf2eec

View File

@@ -666,9 +666,14 @@ def main(script_args: ScriptArguments):
# must contain ints only (raise on lists of strings so callers get
# a clear error). Also accept [[int, int, ...]] nesting for the
# rare case callers pass a single-prompt batch.
if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list):
if (
isinstance(prompt_raw, list)
and prompt_raw
and isinstance(prompt_raw[0], list)
):
prompt_raw = prompt_raw[0]
prompt_dict: dict[str, Any] = {}
if isinstance(prompt_raw, list):
prompt_dict = {"prompt_token_ids": prompt_raw}
elif isinstance(prompt_raw, str):
@@ -676,9 +681,7 @@ def main(script_args: ScriptArguments):
else:
return {
"error": {
"message": (
"prompt must be a string or a list of token ids"
),
"message": ("prompt must be a string or a list of token ids"),
"type": "invalid_request",
}
}
@@ -716,9 +719,7 @@ def main(script_args: ScriptArguments):
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
)
all_outputs = [
o for o, c in zip(all_outputs, chunked, strict=True) if c
]
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
for o in all_outputs:
if isinstance(o, dict) and "error" in o:
raise RuntimeError(f"vLLM worker error: {o['error']}")
@@ -749,7 +750,9 @@ def main(script_args: ScriptArguments):
choice = {
"index": i * n + j,
"text": text,
"finish_reason": "stop" if out.finish_reason == "stop" else "length",
"finish_reason": "stop"
if out.finish_reason == "stop"
else "length",
"logprobs": lp_block,
# NeMo-Gym / retrace agent extras — preserved on the
# choice so callers with raw-token pipelines don't