fix missing host/port for vllm (#2543)
* fix missing host/port for vllm
* set tensor parallel size so it doesn't always default to the CLI override
This commit is contained in:
@@ -39,16 +39,16 @@ class TrainerCliArgs:
|
|||||||
class VllmServeCliArgs:
|
class VllmServeCliArgs:
|
||||||
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||||
|
|
||||||
tensor_parallel_size: int = field(
|
tensor_parallel_size: Optional[int] = field(
|
||||||
default=1,
|
default=None,
|
||||||
metadata={"help": "Number of tensor parallel workers to use."},
|
metadata={"help": "Number of tensor parallel workers to use."},
|
||||||
)
|
)
|
||||||
host: str = field(
|
host: Optional[str] = field(
|
||||||
default="0.0.0.0", # nosec B104
|
default=None, # nosec B104
|
||||||
metadata={"help": "Host address to run the server on."},
|
metadata={"help": "Host address to run the server on."},
|
||||||
)
|
)
|
||||||
port: int = field(
|
port: Optional[int] = field(
|
||||||
default=8000,
|
default=None,
|
||||||
metadata={"help": "Port to run the server on."},
|
metadata={"help": "Port to run the server on."},
|
||||||
)
|
)
|
||||||
gpu_memory_utilization: Optional[float] = field(
|
gpu_memory_utilization: Optional[float] = field(
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
if trl.use_vllm:
|
if trl.use_vllm:
|
||||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
|
||||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
|
||||||
if trl.vllm_server_timeout:
|
if trl.vllm_server_timeout:
|
||||||
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||||
if trl.vllm_guided_decoding_regex:
|
if trl.vllm_guided_decoding_regex:
|
||||||
|
|||||||
@@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Enable prefix caching for VLLM"},
|
json_schema_extra={"description": "Enable prefix caching for VLLM"},
|
||||||
)
|
)
|
||||||
|
host: str | None = Field(
|
||||||
|
default="0.0.0.0", # nosec B104
|
||||||
|
json_schema_extra={"description": "Host for the vLLM server to start on"},
|
||||||
|
)
|
||||||
|
port: int | None = Field(
|
||||||
|
default=8000,
|
||||||
|
json_schema_extra={"description": "Port of the vLLM server to start on"},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user