typing fixes
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Module for axolotl CLI command arguments."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -10,8 +11,8 @@ class PreprocessCliArgs:
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: str | None = field(default=None)
|
||||
download: bool | None = field(default=True)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,7 +23,7 @@ class TrainerCliArgs:
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=0)
|
||||
merge_lora: bool = field(default=False)
|
||||
prompter: str | None = field(default=None)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
@@ -39,4 +40,4 @@ class EvaluateCliArgs:
|
||||
class InferenceCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl inference` command."""
|
||||
|
||||
prompter: str | None = field(default=None)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
@@ -8,18 +8,7 @@ import logging
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
import requests
|
||||
@@ -64,7 +53,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
config_class: Dataclass with fields to parse from the CLI
|
||||
"""
|
||||
|
||||
def decorator(function):
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = field.type
|
||||
@@ -89,6 +78,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
@@ -102,7 +92,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
"""
|
||||
|
||||
def decorator(function):
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
if field.annotation == bool:
|
||||
@@ -116,12 +106,13 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
@@ -151,7 +142,7 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
||||
|
||||
def download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> Tuple[str, str]:
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
@@ -204,7 +195,7 @@ def download_file(
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
@@ -238,7 +229,7 @@ def fetch_from_github(
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: Dict[str, List[str]] = {
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
@@ -281,7 +272,7 @@ def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||
"""
|
||||
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Reference in New Issue
Block a user