chore: refactor
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
"""Module for axolotl CLI command arguments."""
|
"""Module for axolotl CLI command arguments."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -11,9 +10,9 @@ class PreprocessCliArgs:
|
|||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=1)
|
debug_num_examples: int = field(default=1)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: str | None = field(default=None)
|
||||||
download: Optional[bool] = field(default=True)
|
download: bool | None = field(default=True)
|
||||||
iterable: Optional[bool] = field(
|
iterable: bool | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Use IterableDataset for streaming processing of large datasets"
|
"help": "Use IterableDataset for streaming processing of large datasets"
|
||||||
@@ -29,29 +28,29 @@ class TrainerCliArgs:
|
|||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=0)
|
debug_num_examples: int = field(default=0)
|
||||||
merge_lora: bool = field(default=False)
|
merge_lora: bool = field(default=False)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: str | None = field(default=None)
|
||||||
shard: bool = field(default=False)
|
shard: bool = field(default=False)
|
||||||
main_process_port: Optional[int] = field(default=None)
|
main_process_port: int | None = field(default=None)
|
||||||
num_processes: Optional[int] = field(default=None)
|
num_processes: int | None = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VllmServeCliArgs:
|
class VllmServeCliArgs:
|
||||||
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||||
|
|
||||||
tensor_parallel_size: Optional[int] = field(
|
tensor_parallel_size: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Number of tensor parallel workers to use."},
|
metadata={"help": "Number of tensor parallel workers to use."},
|
||||||
)
|
)
|
||||||
host: Optional[str] = field(
|
host: str | None = field(
|
||||||
default=None, # nosec B104
|
default=None, # nosec B104
|
||||||
metadata={"help": "Host address to run the server on."},
|
metadata={"help": "Host address to run the server on."},
|
||||||
)
|
)
|
||||||
port: Optional[int] = field(
|
port: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Port to run the server on."},
|
metadata={"help": "Port to run the server on."},
|
||||||
)
|
)
|
||||||
gpu_memory_utilization: Optional[float] = field(
|
gpu_memory_utilization: float | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||||
@@ -60,14 +59,14 @@ class VllmServeCliArgs:
|
|||||||
"out-of-memory (OOM) errors during initialization."
|
"out-of-memory (OOM) errors during initialization."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dtype: Optional[str] = field(
|
dtype: str | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_model_len: Optional[int] = field(
|
max_model_len: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||||
@@ -75,14 +74,14 @@ class VllmServeCliArgs:
|
|||||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
enable_prefix_caching: Optional[bool] = field(
|
enable_prefix_caching: bool | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||||
"hardware support this feature."
|
"hardware support this feature."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
serve_module: Optional[str] = field(
|
serve_module: str | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Module to serve. If not set, the default module will be used."
|
"help": "Module to serve. If not set, the default module will be used."
|
||||||
@@ -103,4 +102,4 @@ class EvaluateCliArgs:
|
|||||||
class InferenceCliArgs:
|
class InferenceCliArgs:
|
||||||
"""Dataclass with CLI arguments for `axolotl inference` command."""
|
"""Dataclass with CLI arguments for `axolotl inference` command."""
|
||||||
|
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: str | None = field(default=None)
|
||||||
|
|||||||
Reference in New Issue
Block a user