diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 5b430b31e..c89715719 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.common.datasets import load_datasets, load_dpo_datasets +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.evaluate import evaluate from axolotl.utils.dict import DictDefault @@ -35,7 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: check_user_token() if cfg.rl: - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 18f87acf5..760fe76fa 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -17,7 +17,7 @@ from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.common.datasets import load_datasets, load_dpo_datasets +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.trainer import disable_datasets_caching @@ -48,7 +48,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: with disable_datasets_caching(): if cfg.rl: - load_dpo_datasets(cfg=cfg, cli_args=cli_args) + load_preference_datasets(cfg=cfg, cli_args=cli_args) else: load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 
320a40153..9e3ae1cc3 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.common.datasets import load_datasets, load_dpo_datasets +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train from axolotl.utils.dict import DictDefault @@ -35,7 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: check_user_token() if cfg.rl: - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index c0676b128..addfa0ab9 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -5,6 +5,7 @@ import dataclasses import hashlib import json import logging +import typing from functools import wraps from pathlib import Path from types import NoneType @@ -23,6 +24,25 @@ configure_logging() LOG = logging.getLogger(__name__) +def strip_optional_type(field_type: type | typing._SpecialForm | None): + """ + Extracts the non-`None` type from an `Optional` / `Union` type. + + Args: + field_type: Type of field for Axolotl CLI command. + + Returns: + If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise + returns the input type unchanged. + """ + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if t is not NoneType + ) + + return field_type + + def filter_none_kwargs(func: Callable) -> Callable: """ Wraps function to remove `None`-valued `kwargs`. 
@@ -49,18 +69,17 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable: Create Click options from the fields of a dataclass. Args: - config_class: Dataclass with fields to parse from the CLI + config_class: Dataclass with fields to parse from the CLI. + + Returns: + Function decorator for Axolotl CLI command. """ def decorator(function: Callable) -> Callable: # Process dataclass fields in reverse order for correct option ordering for field in reversed(dataclasses.fields(config_class)): - field_type = field.type + field_type = strip_optional_type(field.type) - if get_origin(field_type) is Union and type(None) in get_args(field_type): - field_type = next( - t for t in get_args(field_type) if not isinstance(t, NoneType) - ) if field_type == bool: field_name = field.name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" @@ -89,12 +108,17 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable: Args: config_class: PyDantic model with fields to parse from the CLI + + Returns: + Function decorator for Axolotl CLI command. """ def decorator(function: Callable) -> Callable: # Process model fields in reverse order for correct option ordering for name, field in reversed(config_class.model_fields.items()): - if field.annotation == bool: + field_type = strip_optional_type(field.annotation) + + if field_type == bool: field_name = name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" function = click.option( @@ -116,11 +140,11 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: Build command list from base command and options. Args: - base_cmd: Command without options - options: Options to parse and append to base command + base_cmd: Command without options. + options: Options to parse and append to base command. Returns: - List of strings giving shell command + List of strings giving shell command. 
""" cmd = base_cmd.copy() @@ -146,13 +170,13 @@ def download_file( Download a single file and return its processing status. Args: - file_info: Tuple of (file_path, remote_sha) - raw_base_url: Base URL for raw GitHub content - dest_path: Local destination directory - dir_prefix: Directory prefix to filter files + file_info: Tuple of (file_path, remote_sha). + raw_base_url: Base URL for raw GitHub content. + dest_path: Local destination directory. + dir_prefix: Directory prefix to filter files. Returns: - Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged' + Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. """ file_path, remote_sha = file_info raw_url = f"{raw_base_url}/{file_path}" @@ -201,9 +225,10 @@ def fetch_from_github( Only downloads files that don't exist locally or have changed. Args: - dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/') - dest_dir: Local destination directory - max_workers: Maximum number of concurrent downloads + dir_prefix: Directory prefix to filter files (e.g., 'examples/', + 'deepspeed_configs/'). + dest_dir: Local destination directory. + max_workers: Maximum number of concurrent downloads. 
""" api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 12be18ea5..d07add29b 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -97,12 +97,10 @@ def load_datasets( ) -def load_dpo_datasets( +def load_preference_datasets( *, cfg: DictDefault, - cli_args: Union[ - PreprocessCliArgs, TrainerCliArgs - ], # pylint: disable=unused-argument + cli_args: Union[PreprocessCliArgs, TrainerCliArgs], ) -> TrainDatasetMeta: """ Loads one or more training or evaluation datasets for DPO training, calling diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 45c38ecff..2d0baceee 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -10,7 +10,7 @@ from pathlib import Path import pytest from axolotl.cli.args import TrainerCliArgs -from axolotl.common.datasets import load_dpo_datasets +from axolotl.common.datasets import load_preference_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) 
normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)