diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py
index 9ddef5230..0618e07f1 100644
--- a/src/axolotl/cli/args.py
+++ b/src/axolotl/cli/args.py
@@ -1,6 +1,7 @@
 """Module for axolotl CLI command arguments."""
 
 from dataclasses import dataclass, field
+from typing import Optional
 
 
 @dataclass
@@ -10,8 +11,8 @@ class PreprocessCliArgs:
     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
     debug_num_examples: int = field(default=1)
-    prompter: str | None = field(default=None)
-    download: bool | None = field(default=True)
+    prompter: Optional[str] = field(default=None)
+    download: Optional[bool] = field(default=True)
 
 
 @dataclass
@@ -22,7 +23,7 @@ class TrainerCliArgs:
     debug_text_only: bool = field(default=False)
     debug_num_examples: int = field(default=0)
     merge_lora: bool = field(default=False)
-    prompter: str | None = field(default=None)
+    prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)
 
 
@@ -39,4 +40,4 @@ class EvaluateCliArgs:
 class InferenceCliArgs:
     """Dataclass with CLI arguments for `axolotl inference` command."""
 
-    prompter: str | None = field(default=None)
+    prompter: Optional[str] = field(default=None)
diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py
index 435637688..d19d73246 100644
--- a/src/axolotl/cli/utils.py
+++ b/src/axolotl/cli/utils.py
@@ -8,18 +8,7 @@ import logging
 from functools import wraps
 from pathlib import Path
 from types import NoneType
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Any, Callable, Type, Union, get_args, get_origin
 
 import click
 import requests
@@ -64,7 +53,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
         config_class: Dataclass with fields to parse from the CLI
     """
 
-    def decorator(function):
+    def decorator(function: Callable) -> Callable:
         # Process dataclass fields in reverse order for correct option ordering
         for field in reversed(dataclasses.fields(config_class)):
             field_type = field.type
@@ -89,6 +78,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
                 default=field.default,
                 help=field.metadata.get("description"),
             )(function)
+        return function
 
     return decorator
 
@@ -102,7 +92,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
         config_class: PyDantic model with fields to parse from the CLI
     """
 
-    def decorator(function):
+    def decorator(function: Callable) -> Callable:
         # Process model fields in reverse order for correct option ordering
         for name, field in reversed(config_class.model_fields.items()):
             if field.annotation == bool:
@@ -116,12 +106,13 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
             function = click.option(
                 option_name, default=None, help=field.description
             )(function)
 
+        return function
 
     return decorator
 
 
-def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
+def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
     """
     Build command list from base command and options.
 
@@ -151,7 +142,7 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
 
 def download_file(
     file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     """
     Download a single file and return its processing status.
 
@@ -204,7 +195,7 @@ def download_file(
 
 
 def fetch_from_github(
-    dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
+    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
 ) -> None:
     """
     Sync files from a specific directory in the GitHub repository.
@@ -238,7 +229,7 @@ def fetch_from_github(
     dest_path = Path(dest_dir) if dest_dir else default_dest
 
     # Keep track of processed files for summary
-    files_processed: Dict[str, List[str]] = {
+    files_processed: dict[str, list[str]] = {
         "new": [],
         "updated": [],
         "unchanged": [],
@@ -281,7 +272,7 @@ def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
     inference: bool = False,
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
+) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
    """
    Helper function for loading a model and tokenizer specified in the given
    `axolotl` config.