typing fixes

This commit is contained in:
Dan Saunders
2025-01-10 17:48:28 +00:00
parent c9e37496cb
commit 705e7dc270
2 changed files with 15 additions and 23 deletions

View File

@@ -1,6 +1,7 @@
"""Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
@@ -10,8 +11,8 @@ class PreprocessCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: str | None = field(default=None)
download: bool | None = field(default=True)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
@dataclass
@@ -22,7 +23,7 @@ class TrainerCliArgs:
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: str | None = field(default=None)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@@ -39,4 +40,4 @@ class EvaluateCliArgs:
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: str | None = field(default=None)
prompter: Optional[str] = field(default=None)

View File

@@ -8,18 +8,7 @@ import logging
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
get_args,
get_origin,
)
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
@@ -64,7 +53,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
config_class: Dataclass with fields to parse from the CLI
"""
def decorator(function):
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
@@ -89,6 +78,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
@@ -102,7 +92,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
config_class: PyDantic model with fields to parse from the CLI
"""
def decorator(function):
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
@@ -116,12 +106,13 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
@@ -151,7 +142,7 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
@@ -204,7 +195,7 @@ def download_file(
def fetch_from_github(
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
@@ -238,7 +229,7 @@ def fetch_from_github(
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
@@ -281,7 +272,7 @@ def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.