review comments

This commit is contained in:
Dan Saunders
2025-01-10 17:27:03 +00:00
parent 2b7b37413d
commit 5ff1322f32
16 changed files with 130 additions and 158 deletions

42
src/axolotl/cli/args.py Normal file
View File

@@ -0,0 +1,42 @@
"""Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field
@dataclass
class PreprocessCliArgs:
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: str | None = field(default=None)
download: bool | None = field(default=True)
@dataclass
class TrainerCliArgs:
"""Dataclass with CLI arguments for `axolotl train` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: str | None = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: str | None = field(default=None)

View File

@@ -1,8 +1,5 @@
"""Axolotl ASCII logo utils."""
from art import text2art
from transformers.utils.import_utils import _is_package_available
from axolotl.utils.distributed import is_main_process
AXOLOTL_LOGO = """
@@ -20,38 +17,6 @@ AXOLOTL_LOGO = """
"""
def print_dep_versions():
"""Prints versions of various axolotl dependencies."""
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40)
def print_legacy_axolotl_text_art(suffix=None):
"""
Prints axolotl ASCII art and dependency versions.
Args:
suffix: Text to append to ASCII art text.
"""
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
if is_main_process():

View File

@@ -14,7 +14,7 @@ configure_logging()
LOG = logging.getLogger(__name__)
def check_accelerate_default_config():
def check_accelerate_default_config() -> None:
"""Logs at warning level if no accelerate config file is found."""
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
@@ -22,7 +22,7 @@ def check_accelerate_default_config():
)
def check_user_token():
def check_user_token() -> bool:
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
Returns:

View File

@@ -94,7 +94,7 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
raise err
def choose_config(path: Path):
def choose_config(path: Path) -> str:
"""
Helper method for choosing a `axolotl` config YAML file (considering only files
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
@@ -152,7 +152,7 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.

View File

@@ -8,11 +8,11 @@ import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets, load_rl_datasets
from axolotl.common.datasets import load_datasets, load_dpo_datasets
from axolotl.evaluate import evaluate
from axolotl.utils.dict import DictDefault
@@ -34,8 +34,8 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -13,9 +13,10 @@ import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import InferenceCliArgs, load_model_and_tokenizer
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
@@ -33,10 +34,11 @@ def get_multi_line_input() -> str:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction

View File

@@ -7,6 +7,7 @@ from typing import Optional
import click
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -14,7 +15,6 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

View File

@@ -8,9 +8,10 @@ import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -31,10 +32,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
try:
model.to(dtype=cfg.torch_dtype)
except RuntimeError:
pass
model.to(dtype=cfg.torch_dtype)
model.generation_config.do_sample = True
if cfg.local_rank == 0:

View File

@@ -24,9 +24,9 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs
LOG = logging.getLogger(__name__)

View File

@@ -12,12 +12,12 @@ from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli.args import PreprocessCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_rl_datasets
from axolotl.common.datasets import load_datasets, load_dpo_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
@@ -47,8 +47,8 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if cfg.rl: # and cfg.rl != "orpo":
load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
load_dpo_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -8,11 +8,11 @@ import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets, load_rl_datasets
from axolotl.common.datasets import load_datasets, load_dpo_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.dict import DictDefault
@@ -34,8 +34,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -8,16 +8,34 @@ import logging
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
get_args,
get_origin,
)
import click
import requests
from pydantic import BaseModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)
def filter_none_kwargs(func):
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
@@ -29,15 +47,16 @@ def filter_none_kwargs(func):
"""
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]):
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
@@ -75,7 +94,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
return decorator
def add_options_from_config(config_class: Type[BaseModel]):
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
@@ -256,3 +275,28 @@ def fetch_from_github(
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

View File

@@ -1,79 +0,0 @@
"""Shared module for CLI specific utilities."""
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)
@dataclass
class PreprocessCliArgs:
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
@dataclass
class TrainerCliArgs:
"""Dataclass with CLI arguments for `axolotl train` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

View File

@@ -8,7 +8,7 @@ from typing import Optional, Union
from datasets import Dataset
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
@@ -27,7 +27,7 @@ class TrainDatasetMeta:
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int):
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.
@@ -96,7 +96,7 @@ def load_datasets(
)
def load_rl_datasets(
def load_dpo_datasets(
*,
cfg: DictDefault,
cli_args: Union[
@@ -104,7 +104,7 @@ def load_rl_datasets(
], # pylint: disable=unused-argument
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training, calling
Loads one or more training or evaluation datasets for DPO training, calling
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
information.

View File

@@ -66,7 +66,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
Evaluate a model on training and validation datasets
Args:
cfg: Config dictionary.
cfg: Dictionary mapping `axolotl` config keys to values.
dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns:

View File

@@ -9,8 +9,8 @@ from pathlib import Path
import pytest
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_rl_datasets
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_dpo_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)