review comments

This commit is contained in:
Dan Saunders
2025-01-13 17:05:21 +00:00
parent 18a36b31ef
commit 3b82fc36ec
6 changed files with 59 additions and 36 deletions

View File

@@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
-from axolotl.common.datasets import load_datasets, load_dpo_datasets
+from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils.dict import DictDefault
@@ -35,7 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
check_user_token()
if cfg.rl:
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -17,7 +17,7 @@ from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.common.datasets import load_datasets, load_dpo_datasets
+from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
@@ -48,7 +48,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
with disable_datasets_caching():
if cfg.rl:
-load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
-from axolotl.common.datasets import load_datasets, load_dpo_datasets
+from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.dict import DictDefault
@@ -35,7 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
check_user_token()
if cfg.rl:
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -5,6 +5,7 @@ import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -23,6 +24,25 @@ configure_logging()
LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | typing._SpecialForm | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
@@ -49,18 +69,17 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
Create Click options from the fields of a dataclass.
Args:
-config_class: Dataclass with fields to parse from the CLI
+config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
-field_type = field.type
+field_type = strip_optional_type(field.type)
-if get_origin(field_type) is Union and type(None) in get_args(field_type):
-    field_type = next(
-        t for t in get_args(field_type) if not isinstance(t, NoneType)
-    )
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
@@ -89,12 +108,17 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
Args:
config_class: Pydantic model with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
-if field.annotation == bool:
+field_type = strip_optional_type(field.annotation)
+if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -116,11 +140,11 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
Build command list from base command and options.
Args:
-base_cmd: Command without options
-options: Options to parse and append to base command
+base_cmd: Command without options.
+options: Options to parse and append to base command.
Returns:
-List of strings giving shell command
+List of strings giving shell command.
"""
cmd = base_cmd.copy()
@@ -146,13 +170,13 @@ def download_file(
Download a single file and return its processing status.
Args:
-file_info: Tuple of (file_path, remote_sha)
-raw_base_url: Base URL for raw GitHub content
-dest_path: Local destination directory
-dir_prefix: Directory prefix to filter files
+file_info: Tuple of (file_path, remote_sha).
+raw_base_url: Base URL for raw GitHub content.
+dest_path: Local destination directory.
+dir_prefix: Directory prefix to filter files.
Returns:
-Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
+Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
@@ -201,9 +225,10 @@ def fetch_from_github(
Only downloads files that don't exist locally or have changed.
Args:
-dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
-dest_dir: Local destination directory
-max_workers: Maximum number of concurrent downloads
+dir_prefix: Directory prefix to filter files (e.g., 'examples/',
+    'deepspeed_configs/').
+dest_dir: Local destination directory.
+max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

View File

@@ -97,12 +97,10 @@ def load_datasets(
)
-def load_dpo_datasets(
+def load_preference_datasets(
*,
cfg: DictDefault,
-cli_args: Union[
-    PreprocessCliArgs, TrainerCliArgs
-],  # pylint: disable=unused-argument
+cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for preference tuning (e.g., DPO), calling

View File

@@ -10,7 +10,7 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_dpo_datasets
+from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
-dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
+dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)