lint
This commit is contained in:
@@ -666,9 +666,14 @@ def main(script_args: ScriptArguments):
|
||||
# must contain ints only (raise on lists of strings so callers get
|
||||
# a clear error). Also accept [[int, int, ...]] nesting for the
|
||||
# rare case callers pass a single-prompt batch.
|
||||
if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list):
|
||||
if (
|
||||
isinstance(prompt_raw, list)
|
||||
and prompt_raw
|
||||
and isinstance(prompt_raw[0], list)
|
||||
):
|
||||
prompt_raw = prompt_raw[0]
|
||||
|
||||
prompt_dict: dict[str, Any] = {}
|
||||
if isinstance(prompt_raw, list):
|
||||
prompt_dict = {"prompt_token_ids": prompt_raw}
|
||||
elif isinstance(prompt_raw, str):
|
||||
@@ -676,9 +681,7 @@ def main(script_args: ScriptArguments):
|
||||
else:
|
||||
return {
|
||||
"error": {
|
||||
"message": (
|
||||
"prompt must be a string or a list of token ids"
|
||||
),
|
||||
"message": ("prompt must be a string or a list of token ids"),
|
||||
"type": "invalid_request",
|
||||
}
|
||||
}
|
||||
@@ -716,9 +719,7 @@ def main(script_args: ScriptArguments):
|
||||
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
||||
)
|
||||
|
||||
all_outputs = [
|
||||
o for o, c in zip(all_outputs, chunked, strict=True) if c
|
||||
]
|
||||
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||
for o in all_outputs:
|
||||
if isinstance(o, dict) and "error" in o:
|
||||
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
||||
@@ -749,7 +750,9 @@ def main(script_args: ScriptArguments):
|
||||
choice = {
|
||||
"index": i * n + j,
|
||||
"text": text,
|
||||
"finish_reason": "stop" if out.finish_reason == "stop" else "length",
|
||||
"finish_reason": "stop"
|
||||
if out.finish_reason == "stop"
|
||||
else "length",
|
||||
"logprobs": lp_block,
|
||||
# NeMo-Gym / retrace agent extras — preserved on the
|
||||
# choice so callers with raw-token pipelines don't
|
||||
|
||||
Reference in New Issue
Block a user