chore: refactor CLI argument dataclasses to use `X | None` annotations instead of `Optional[X]`

This commit is contained in:
NanoCode012
2025-05-27 10:44:01 +07:00
parent cc87f910d3
commit 99c8859eb0

View File

@@ -1,7 +1,6 @@
"""Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
@@ -11,9 +10,9 @@ class PreprocessCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
prompter: str | None = field(default=None)
download: bool | None = field(default=True)
iterable: bool | None = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
@@ -29,29 +28,29 @@ class TrainerCliArgs:
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
prompter: str | None = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
main_process_port: int | None = field(default=None)
num_processes: int | None = field(default=None)
@dataclass
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: Optional[int] = field(
tensor_parallel_size: int | None = field(
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: Optional[str] = field(
host: str | None = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: Optional[int] = field(
port: int | None = field(
default=None,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(
gpu_memory_utilization: float | None = field(
default=None,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
@@ -60,14 +59,14 @@ class VllmServeCliArgs:
"out-of-memory (OOM) errors during initialization."
},
)
dtype: Optional[str] = field(
dtype: str | None = field(
default=None,
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: Optional[int] = field(
max_model_len: int | None = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
@@ -75,14 +74,14 @@ class VllmServeCliArgs:
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: Optional[bool] = field(
enable_prefix_caching: bool | None = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
serve_module: Optional[str] = field(
serve_module: str | None = field(
default=None,
metadata={
"help": "Module to serve. If not set, the default module will be used."
@@ -103,4 +102,4 @@ class EvaluateCliArgs:
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)
prompter: str | None = field(default=None)