typing fixes
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""Module for axolotl CLI command arguments."""
|
"""Module for axolotl CLI command arguments."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -10,8 +11,8 @@ class PreprocessCliArgs:
|
|||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=1)
|
debug_num_examples: int = field(default=1)
|
||||||
prompter: str | None = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
download: bool | None = field(default=True)
|
download: Optional[bool] = field(default=True)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -22,7 +23,7 @@ class TrainerCliArgs:
|
|||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=0)
|
debug_num_examples: int = field(default=0)
|
||||||
merge_lora: bool = field(default=False)
|
merge_lora: bool = field(default=False)
|
||||||
prompter: str | None = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
shard: bool = field(default=False)
|
shard: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -39,4 +40,4 @@ class EvaluateCliArgs:
|
|||||||
class InferenceCliArgs:
|
class InferenceCliArgs:
|
||||||
"""Dataclass with CLI arguments for `axolotl inference` command."""
|
"""Dataclass with CLI arguments for `axolotl inference` command."""
|
||||||
|
|
||||||
prompter: str | None = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
|
|||||||
@@ -8,18 +8,7 @@ import logging
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
from typing import (
|
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
)
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
@@ -64,7 +53,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
|||||||
config_class: Dataclass with fields to parse from the CLI
|
config_class: Dataclass with fields to parse from the CLI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(function):
|
def decorator(function: Callable) -> Callable:
|
||||||
# Process dataclass fields in reverse order for correct option ordering
|
# Process dataclass fields in reverse order for correct option ordering
|
||||||
for field in reversed(dataclasses.fields(config_class)):
|
for field in reversed(dataclasses.fields(config_class)):
|
||||||
field_type = field.type
|
field_type = field.type
|
||||||
@@ -89,6 +78,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
|||||||
default=field.default,
|
default=field.default,
|
||||||
help=field.metadata.get("description"),
|
help=field.metadata.get("description"),
|
||||||
)(function)
|
)(function)
|
||||||
|
|
||||||
return function
|
return function
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -102,7 +92,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
|||||||
config_class: PyDantic model with fields to parse from the CLI
|
config_class: PyDantic model with fields to parse from the CLI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(function):
|
def decorator(function: Callable) -> Callable:
|
||||||
# Process model fields in reverse order for correct option ordering
|
# Process model fields in reverse order for correct option ordering
|
||||||
for name, field in reversed(config_class.model_fields.items()):
|
for name, field in reversed(config_class.model_fields.items()):
|
||||||
if field.annotation == bool:
|
if field.annotation == bool:
|
||||||
@@ -116,12 +106,13 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
|||||||
function = click.option(
|
function = click.option(
|
||||||
option_name, default=None, help=field.description
|
option_name, default=None, help=field.description
|
||||||
)(function)
|
)(function)
|
||||||
|
|
||||||
return function
|
return function
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Build command list from base command and options.
|
Build command list from base command and options.
|
||||||
|
|
||||||
@@ -151,7 +142,7 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
|||||||
|
|
||||||
def download_file(
|
def download_file(
|
||||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||||
) -> Tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Download a single file and return its processing status.
|
Download a single file and return its processing status.
|
||||||
|
|
||||||
@@ -204,7 +195,7 @@ def download_file(
|
|||||||
|
|
||||||
|
|
||||||
def fetch_from_github(
|
def fetch_from_github(
|
||||||
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
|
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Sync files from a specific directory in the GitHub repository.
|
Sync files from a specific directory in the GitHub repository.
|
||||||
@@ -238,7 +229,7 @@ def fetch_from_github(
|
|||||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||||
|
|
||||||
# Keep track of processed files for summary
|
# Keep track of processed files for summary
|
||||||
files_processed: Dict[str, List[str]] = {
|
files_processed: dict[str, list[str]] = {
|
||||||
"new": [],
|
"new": [],
|
||||||
"updated": [],
|
"updated": [],
|
||||||
"unchanged": [],
|
"unchanged": [],
|
||||||
@@ -281,7 +272,7 @@ def load_model_and_tokenizer(
|
|||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
||||||
config.
|
config.
|
||||||
|
|||||||
Reference in New Issue
Block a user