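chore: refactor CLI argument type hints from `Optional[X]` to `X | None` and drop the unused `typing.Optional` import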

This commit is contained in:
NanoCode012
2025-05-27 10:44:01 +07:00
parent cc87f910d3
commit 99c8859eb0

View File

@@ -1,7 +1,6 @@
"""Module for axolotl CLI command arguments.""" """Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
@dataclass @dataclass
@@ -11,9 +10,9 @@ class PreprocessCliArgs:
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1) debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None) prompter: str | None = field(default=None)
download: Optional[bool] = field(default=True) download: bool | None = field(default=True)
iterable: Optional[bool] = field( iterable: bool | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Use IterableDataset for streaming processing of large datasets" "help": "Use IterableDataset for streaming processing of large datasets"
@@ -29,29 +28,29 @@ class TrainerCliArgs:
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0) debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None) prompter: str | None = field(default=None)
shard: bool = field(default=False) shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None) main_process_port: int | None = field(default=None)
num_processes: Optional[int] = field(default=None) num_processes: int | None = field(default=None)
@dataclass @dataclass
class VllmServeCliArgs: class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command.""" """Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: Optional[int] = field( tensor_parallel_size: int | None = field(
default=None, default=None,
metadata={"help": "Number of tensor parallel workers to use."}, metadata={"help": "Number of tensor parallel workers to use."},
) )
host: Optional[str] = field( host: str | None = field(
default=None, # nosec B104 default=None, # nosec B104
metadata={"help": "Host address to run the server on."}, metadata={"help": "Host address to run the server on."},
) )
port: Optional[int] = field( port: int | None = field(
default=None, default=None,
metadata={"help": "Port to run the server on."}, metadata={"help": "Port to run the server on."},
) )
gpu_memory_utilization: Optional[float] = field( gpu_memory_utilization: float | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
@@ -60,14 +59,14 @@ class VllmServeCliArgs:
"out-of-memory (OOM) errors during initialization." "out-of-memory (OOM) errors during initialization."
}, },
) )
dtype: Optional[str] = field( dtype: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically " "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation." "determined based on the model configuration. Find the supported values in the vLLM documentation."
}, },
) )
max_model_len: Optional[int] = field( max_model_len: int | None = field(
default=None, default=None,
metadata={ metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced " "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
@@ -75,14 +74,14 @@ class VllmServeCliArgs:
"context size, which might be much larger than the KV cache, leading to inefficiencies." "context size, which might be much larger than the KV cache, leading to inefficiencies."
}, },
) )
enable_prefix_caching: Optional[bool] = field( enable_prefix_caching: bool | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the " "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature." "hardware support this feature."
}, },
) )
serve_module: Optional[str] = field( serve_module: str | None = field(
default=None, default=None,
metadata={ metadata={
"help": "Module to serve. If not set, the default module will be used." "help": "Module to serve. If not set, the default module will be used."
@@ -103,4 +102,4 @@ class EvaluateCliArgs:
class InferenceCliArgs: class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command.""" """Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None) prompter: str | None = field(default=None)