review comments

This commit is contained in:
Dan Saunders
2025-01-13 17:05:21 +00:00
parent 18a36b31ef
commit 3b82fc36ec
6 changed files with 59 additions and 36 deletions

View File

@@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_dpo_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate from axolotl.evaluate import evaluate
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -35,7 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
check_user_token() check_user_token()
if cfg.rl: if cfg.rl:
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -17,7 +17,7 @@ from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_dpo_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching from axolotl.utils.trainer import disable_datasets_caching
@@ -48,7 +48,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
with disable_datasets_caching(): with disable_datasets_caching():
if cfg.rl: if cfg.rl:
load_dpo_datasets(cfg=cfg, cli_args=cli_args) load_preference_datasets(cfg=cfg, cli_args=cli_args)
else: else:
load_datasets(cfg=cfg, cli_args=cli_args) load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_dpo_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.train import train from axolotl.train import train
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -35,7 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
check_user_token() check_user_token()
if cfg.rl: if cfg.rl:
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -5,6 +5,7 @@ import dataclasses
import hashlib import hashlib
import json import json
import logging import logging
import typing
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from types import NoneType from types import NoneType
@@ -23,6 +24,25 @@ configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | typing._SpecialForm | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable: def filter_none_kwargs(func: Callable) -> Callable:
""" """
Wraps function to remove `None`-valued `kwargs`. Wraps function to remove `None`-valued `kwargs`.
@@ -49,18 +69,17 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
Create Click options from the fields of a dataclass. Create Click options from the fields of a dataclass.
Args: Args:
config_class: Dataclass with fields to parse from the CLI config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
""" """
def decorator(function: Callable) -> Callable: def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering # Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)): for field in reversed(dataclasses.fields(config_class)):
field_type = field.type field_type = strip_optional_type(field.type)
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
if field_type == bool: if field_type == bool:
field_name = field.name.replace("_", "-") field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"
@@ -89,12 +108,17 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
Args: Args:
config_class: PyDantic model with fields to parse from the CLI config_class: PyDantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
""" """
def decorator(function: Callable) -> Callable: def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering # Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()): for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool: field_type = strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-") field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"
function = click.option( function = click.option(
@@ -116,11 +140,11 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
Build command list from base command and options. Build command list from base command and options.
Args: Args:
base_cmd: Command without options base_cmd: Command without options.
options: Options to parse and append to base command options: Options to parse and append to base command.
Returns: Returns:
List of strings giving shell command List of strings giving shell command.
""" """
cmd = base_cmd.copy() cmd = base_cmd.copy()
@@ -146,13 +170,13 @@ def download_file(
Download a single file and return its processing status. Download a single file and return its processing status.
Args: Args:
file_info: Tuple of (file_path, remote_sha) file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files dir_prefix: Directory prefix to filter files.
Returns: Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged' Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
""" """
file_path, remote_sha = file_info file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}" raw_url = f"{raw_base_url}/{file_path}"
@@ -201,9 +225,10 @@ def fetch_from_github(
Only downloads files that don't exist locally or have changed. Only downloads files that don't exist locally or have changed.
Args: Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/') dir_prefix: Directory prefix to filter files (e.g., 'examples/',
dest_dir: Local destination directory 'deepspeed_configs/').
max_workers: Maximum number of concurrent downloads dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
""" """
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

View File

@@ -97,12 +97,10 @@ def load_datasets(
) )
def load_dpo_datasets( def load_preference_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: Union[ cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
PreprocessCliArgs, TrainerCliArgs
], # pylint: disable=unused-argument
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets for DPO training, calling Loads one or more training or evaluation datasets for DPO training, calling

View File

@@ -10,7 +10,7 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_dpo_datasets from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)