chore: refactor Optional[X] type hints to PEP 604 X | None unions in CLI args
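For context: `X | None` is the PEP 604 spelling of `typing.Optional[X]`, available as runtime syntax since Python 3.10. The two forms compare equal at runtime, so the change below is purely cosmetic. A minimal sanity check (illustrative, not part of the commit):

from typing import Optional

# PEP 604 unions compare equal to the typing spellings they replace,
# so swapping Optional[X] for X | None is behavior-preserving.
assert (str | None) == Optional[str]
assert (int | None) == Optional[int]
assert (bool | None) == Optional[bool]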
@@ -1,7 +1,6 @@
 """Module for axolotl CLI command arguments."""

 from dataclasses import dataclass, field
-from typing import Optional


 @dataclass
@@ -11,9 +10,9 @@ class PreprocessCliArgs:
     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
     debug_num_examples: int = field(default=1)
-    prompter: Optional[str] = field(default=None)
-    download: Optional[bool] = field(default=True)
-    iterable: Optional[bool] = field(
+    prompter: str | None = field(default=None)
+    download: bool | None = field(default=True)
+    iterable: bool | None = field(
         default=None,
         metadata={
             "help": "Use IterableDataset for streaming processing of large datasets"
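Note that fields such as `iterable` stay `bool | None` rather than plain `bool`, presumably so an unset flag can be distinguished from an explicit `False`. A hypothetical sketch of that three-state pattern (the resolver below is illustrative, not axolotl code):

def resolve_flag(cli_value: bool | None, config_default: bool) -> bool:
    # None means the flag was never passed on the command line,
    # so fall back to the config file's value.
    return config_default if cli_value is None else cli_value

assert resolve_flag(None, config_default=True) is True    # unset: config wins
assert resolve_flag(False, config_default=True) is False  # explicit False wins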
@@ -29,29 +28,29 @@ class TrainerCliArgs:
     debug_text_only: bool = field(default=False)
     debug_num_examples: int = field(default=0)
     merge_lora: bool = field(default=False)
-    prompter: Optional[str] = field(default=None)
+    prompter: str | None = field(default=None)
     shard: bool = field(default=False)
-    main_process_port: Optional[int] = field(default=None)
-    num_processes: Optional[int] = field(default=None)
+    main_process_port: int | None = field(default=None)
+    num_processes: int | None = field(default=None)


 @dataclass
 class VllmServeCliArgs:
     """Dataclass with CLI arguments for `axolotl vllm-serve` command."""

-    tensor_parallel_size: Optional[int] = field(
+    tensor_parallel_size: int | None = field(
         default=None,
         metadata={"help": "Number of tensor parallel workers to use."},
     )
-    host: Optional[str] = field(
+    host: str | None = field(
         default=None,  # nosec B104
         metadata={"help": "Host address to run the server on."},
     )
-    port: Optional[int] = field(
+    port: int | None = field(
         default=None,
         metadata={"help": "Port to run the server on."},
     )
-    gpu_memory_utilization: Optional[float] = field(
+    gpu_memory_utilization: float | None = field(
         default=None,
         metadata={
             "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
@@ -60,14 +59,14 @@ class VllmServeCliArgs:
             "out-of-memory (OOM) errors during initialization."
         },
     )
-    dtype: Optional[str] = field(
+    dtype: str | None = field(
         default=None,
         metadata={
             "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
             "determined based on the model configuration. Find the supported values in the vLLM documentation."
         },
     )
-    max_model_len: Optional[int] = field(
+    max_model_len: int | None = field(
         default=None,
         metadata={
             "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
@@ -75,14 +74,14 @@ class VllmServeCliArgs:
             "context size, which might be much larger than the KV cache, leading to inefficiencies."
         },
     )
-    enable_prefix_caching: Optional[bool] = field(
+    enable_prefix_caching: bool | None = field(
         default=None,
         metadata={
             "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
             "hardware support this feature."
         },
     )
-    serve_module: Optional[str] = field(
+    serve_module: str | None = field(
         default=None,
         metadata={
             "help": "Module to serve. If not set, the default module will be used."
@@ -103,4 +102,4 @@ class EvaluateCliArgs:
 class InferenceCliArgs:
     """Dataclass with CLI arguments for `axolotl inference` command."""

-    prompter: Optional[str] = field(default=None)
+    prompter: str | None = field(default=None)
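These dataclasses carry their CLI help text in `field(metadata={"help": ...})`, so a command-line parser can be generated from them by introspection. The diff does not show how axolotl wires this up; the helper below is a hypothetical illustration using only the standard library:

import argparse
import dataclasses

def parser_from_dataclass(cls: type) -> argparse.ArgumentParser:
    """Build an argparse parser from a dataclass, pulling help text from
    each field's metadata. Hypothetical helper, not axolotl's actual CLI
    wiring; argument type conversion is omitted for brevity."""
    parser = argparse.ArgumentParser(description=cls.__doc__)
    for f in dataclasses.fields(cls):
        default = None if f.default is dataclasses.MISSING else f.default
        parser.add_argument(
            f"--{f.name.replace('_', '-')}",
            default=default,
            help=f.metadata.get("help", ""),
        )
    return parser

# Example: parser_from_dataclass(VllmServeCliArgs).parse_args(["--port", "8000"])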