typing fixes

This commit is contained in:
Dan Saunders
2025-01-10 17:48:28 +00:00
parent c9e37496cb
commit 705e7dc270
2 changed files with 15 additions and 23 deletions

View File

@@ -1,6 +1,7 @@
"""Module for axolotl CLI command arguments.""" """Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
@dataclass @dataclass
@@ -10,8 +11,8 @@ class PreprocessCliArgs:
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1) debug_num_examples: int = field(default=1)
prompter: str | None = field(default=None) prompter: Optional[str] = field(default=None)
download: bool | None = field(default=True) download: Optional[bool] = field(default=True)
@dataclass @dataclass
@@ -22,7 +23,7 @@ class TrainerCliArgs:
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0) debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prompter: str | None = field(default=None) prompter: Optional[str] = field(default=None)
shard: bool = field(default=False) shard: bool = field(default=False)
@@ -39,4 +40,4 @@ class EvaluateCliArgs:
class InferenceCliArgs: class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command.""" """Dataclass with CLI arguments for `axolotl inference` command."""
prompter: str | None = field(default=None) prompter: Optional[str] = field(default=None)

View File

@@ -8,18 +8,7 @@ import logging
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from types import NoneType from types import NoneType
from typing import ( from typing import Any, Callable, Type, Union, get_args, get_origin
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
get_args,
get_origin,
)
import click import click
import requests import requests
@@ -64,7 +53,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
config_class: Dataclass with fields to parse from the CLI config_class: Dataclass with fields to parse from the CLI
""" """
def decorator(function): def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering # Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)): for field in reversed(dataclasses.fields(config_class)):
field_type = field.type field_type = field.type
@@ -89,6 +78,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
default=field.default, default=field.default,
help=field.metadata.get("description"), help=field.metadata.get("description"),
)(function) )(function)
return function return function
return decorator return decorator
@@ -102,7 +92,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
config_class: PyDantic model with fields to parse from the CLI config_class: PyDantic model with fields to parse from the CLI
""" """
def decorator(function): def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering # Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()): for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool: if field.annotation == bool:
@@ -116,12 +106,13 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
function = click.option( function = click.option(
option_name, default=None, help=field.description option_name, default=None, help=field.description
)(function) )(function)
return function return function
return decorator return decorator
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
""" """
Build command list from base command and options. Build command list from base command and options.
@@ -151,7 +142,7 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
def download_file( def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]: ) -> tuple[str, str]:
""" """
Download a single file and return its processing status. Download a single file and return its processing status.
@@ -204,7 +195,7 @@ def download_file(
def fetch_from_github( def fetch_from_github(
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5 dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None: ) -> None:
""" """
Sync files from a specific directory in the GitHub repository. Sync files from a specific directory in the GitHub repository.
@@ -238,7 +229,7 @@ def fetch_from_github(
dest_path = Path(dest_dir) if dest_dir else default_dest dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary # Keep track of processed files for summary
files_processed: Dict[str, List[str]] = { files_processed: dict[str, list[str]] = {
"new": [], "new": [],
"updated": [], "updated": [],
"unchanged": [], "unchanged": [],
@@ -281,7 +272,7 @@ def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,
inference: bool = False, inference: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: ) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
""" """
Helper function for loading a model and tokenizer specified in the given `axolotl` Helper function for loading a model and tokenizer specified in the given `axolotl`
config. config.