diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 72e61d1bb..83febd7f4 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -39,16 +39,16 @@ class TrainerCliArgs: class VllmServeCliArgs: """Dataclass with CLI arguments for `axolotl vllm-serve` command.""" - tensor_parallel_size: int = field( - default=1, + tensor_parallel_size: Optional[int] = field( + default=None, metadata={"help": "Number of tensor parallel workers to use."}, ) - host: str = field( - default="0.0.0.0", # nosec B104 + host: Optional[str] = field( + default=None, # nosec B104 metadata={"help": "Host address to run the server on."}, ) - port: int = field( - default=8000, + port: Optional[int] = field( + default=None, metadata={"help": "Port to run the server on."}, ) gpu_memory_utilization: Optional[float] = field( diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 219eced69..0d3e7b7d7 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -40,8 +40,8 @@ class GRPOStrategy: if trl.use_vllm: grpo_args_kwargs["use_vllm"] = trl.use_vllm - grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host - grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port + grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host + grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port if trl.vllm_server_timeout: grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout if trl.vllm_guided_decoding_regex: diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py index bb1a4ba26..48441de5e 100644 --- a/src/axolotl/utils/schemas/vllm.py +++ b/src/axolotl/utils/schemas/vllm.py @@ -36,3 +36,11 @@ class VllmConfig(BaseModel): default=None, json_schema_extra={"description": "Enable prefix caching for VLLM"}, ) + host: str | None = Field( + default="0.0.0.0", # nosec B104 + json_schema_extra={"description": "Host for the vLLM server to start on"}, + ) + port: int | None = Field( + default=8000, + json_schema_extra={"description": "Port of the vLLM server to start on"}, + )