diff --git a/src/axolotl/scripts/vllm_serve_lora.py b/src/axolotl/scripts/vllm_serve_lora.py index 4f0ec12fe..ca2f743fc 100644 --- a/src/axolotl/scripts/vllm_serve_lora.py +++ b/src/axolotl/scripts/vllm_serve_lora.py @@ -666,9 +666,14 @@ def main(script_args: ScriptArguments): # must contain ints only (raise on lists of strings so callers get # a clear error). Also accept [[int, int, ...]] nesting for the # rare case callers pass a single-prompt batch. - if isinstance(prompt_raw, list) and prompt_raw and isinstance(prompt_raw[0], list): + if ( + isinstance(prompt_raw, list) + and prompt_raw + and isinstance(prompt_raw[0], list) + ): prompt_raw = prompt_raw[0] + prompt_dict: dict[str, Any] = {} if isinstance(prompt_raw, list): prompt_dict = {"prompt_token_ids": prompt_raw} elif isinstance(prompt_raw, str): @@ -676,9 +681,7 @@ def main(script_args: ScriptArguments): else: return { "error": { - "message": ( - "prompt must be a string or a list of token ids" - ), + "message": ("prompt must be a string or a list of token ids"), "type": "invalid_request", } } @@ -716,9 +719,7 @@ def main(script_args: ScriptArguments): *(loop.run_in_executor(None, safe_recv, conn) for conn in connections) ) - all_outputs = [ - o for o, c in zip(all_outputs, chunked, strict=True) if c - ] + all_outputs = [o for o, c in zip(all_outputs, chunked, strict=True) if c] for o in all_outputs: if isinstance(o, dict) and "error" in o: raise RuntimeError(f"vLLM worker error: {o['error']}") @@ -749,7 +750,9 @@ def main(script_args: ScriptArguments): choice = { "index": i * n + j, "text": text, - "finish_reason": "stop" if out.finish_reason == "stop" else "length", + "finish_reason": "stop" + if out.finish_reason == "stop" + else "length", "logprobs": lp_block, # NeMo-Gym / retrace agent extras — preserved on the # choice so callers with raw-token pipelines don't