lint
This commit is contained in:
@@ -666,9 +666,14 @@ def main(script_args: ScriptArguments):
|
|||||||
# must contain ints only (raise on lists of strings so callers get
|
# must contain ints only (raise on lists of strings so callers get
|
||||||
# a clear error). Also accept [[int, int, ...]] nesting for the
|
# a clear error). Also accept [[int, int, ...]] nesting for the
|
||||||
# rare case callers pass a single-prompt batch.
|
# rare case callers pass a single-prompt batch.
|
||||||
if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list):
|
if (
|
||||||
|
isinstance(prompt_raw, list)
|
||||||
|
and prompt_raw
|
||||||
|
and isinstance(prompt_raw[0], list)
|
||||||
|
):
|
||||||
prompt_raw = prompt_raw[0]
|
prompt_raw = prompt_raw[0]
|
||||||
|
|
||||||
|
prompt_dict: dict[str, Any] = {}
|
||||||
if isinstance(prompt_raw, list):
|
if isinstance(prompt_raw, list):
|
||||||
prompt_dict = {"prompt_token_ids": prompt_raw}
|
prompt_dict = {"prompt_token_ids": prompt_raw}
|
||||||
elif isinstance(prompt_raw, str):
|
elif isinstance(prompt_raw, str):
|
||||||
@@ -676,9 +681,7 @@ def main(script_args: ScriptArguments):
|
|||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"error": {
|
"error": {
|
||||||
"message": (
|
"message": ("prompt must be a string or a list of token ids"),
|
||||||
"prompt must be a string or a list of token ids"
|
|
||||||
),
|
|
||||||
"type": "invalid_request",
|
"type": "invalid_request",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -716,9 +719,7 @@ def main(script_args: ScriptArguments):
|
|||||||
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
*(loop.run_in_executor(None, safe_recv, conn) for conn in connections)
|
||||||
)
|
)
|
||||||
|
|
||||||
all_outputs = [
|
all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c]
|
||||||
o for o, c in zip(all_outputs, chunked, strict=True) if c
|
|
||||||
]
|
|
||||||
for o in all_outputs:
|
for o in all_outputs:
|
||||||
if isinstance(o, dict) and "error" in o:
|
if isinstance(o, dict) and "error" in o:
|
||||||
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
raise RuntimeError(f"vLLM worker error: {o['error']}")
|
||||||
@@ -749,7 +750,9 @@ def main(script_args: ScriptArguments):
|
|||||||
choice = {
|
choice = {
|
||||||
"index": i * n + j,
|
"index": i * n + j,
|
||||||
"text": text,
|
"text": text,
|
||||||
"finish_reason": "stop" if out.finish_reason == "stop" else "length",
|
"finish_reason": "stop"
|
||||||
|
if out.finish_reason == "stop"
|
||||||
|
else "length",
|
||||||
"logprobs": lp_block,
|
"logprobs": lp_block,
|
||||||
# NeMo-Gym / retrace agent extras — preserved on the
|
# NeMo-Gym / retrace agent extras — preserved on the
|
||||||
# choice so callers with raw-token pipelines don't
|
# choice so callers with raw-token pipelines don't
|
||||||
|
|||||||
Reference in New Issue
Block a user