review comments
This commit is contained in:
@@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_dpo_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.evaluate import evaluate
|
from axolotl.evaluate import evaluate
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from axolotl.cli.art import print_axolotl_text_art
|
|||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.common.datasets import load_datasets, load_dpo_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.trainer import disable_datasets_caching
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
|
|
||||||
with disable_datasets_caching():
|
with disable_datasets_caching():
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
load_datasets(cfg=cfg, cli_args=cli_args)
|
load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from axolotl.cli.args import TrainerCliArgs
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_dpo_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -35,7 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import dataclasses
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import typing
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
@@ -23,6 +24,25 @@ configure_logging()
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_optional_type(field_type: type | typing._SpecialForm | None):
|
||||||
|
"""
|
||||||
|
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_type: Type of field for Axolotl CLI command.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||||
|
returns the input type unchanged.
|
||||||
|
"""
|
||||||
|
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||||
|
field_type = next(
|
||||||
|
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||||
|
)
|
||||||
|
|
||||||
|
return field_type
|
||||||
|
|
||||||
|
|
||||||
def filter_none_kwargs(func: Callable) -> Callable:
|
def filter_none_kwargs(func: Callable) -> Callable:
|
||||||
"""
|
"""
|
||||||
Wraps function to remove `None`-valued `kwargs`.
|
Wraps function to remove `None`-valued `kwargs`.
|
||||||
@@ -49,18 +69,17 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
|||||||
Create Click options from the fields of a dataclass.
|
Create Click options from the fields of a dataclass.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_class: Dataclass with fields to parse from the CLI
|
config_class: Dataclass with fields to parse from the CLI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function decorator for Axolotl CLI command.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(function: Callable) -> Callable:
|
def decorator(function: Callable) -> Callable:
|
||||||
# Process dataclass fields in reverse order for correct option ordering
|
# Process dataclass fields in reverse order for correct option ordering
|
||||||
for field in reversed(dataclasses.fields(config_class)):
|
for field in reversed(dataclasses.fields(config_class)):
|
||||||
field_type = field.type
|
field_type = strip_optional_type(field.type)
|
||||||
|
|
||||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
|
||||||
field_type = next(
|
|
||||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
|
||||||
)
|
|
||||||
if field_type == bool:
|
if field_type == bool:
|
||||||
field_name = field.name.replace("_", "-")
|
field_name = field.name.replace("_", "-")
|
||||||
option_name = f"--{field_name}/--no-{field_name}"
|
option_name = f"--{field_name}/--no-{field_name}"
|
||||||
@@ -89,12 +108,17 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_class: PyDantic model with fields to parse from the CLI
|
config_class: PyDantic model with fields to parse from the CLI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function decorator for Axolotl CLI command.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(function: Callable) -> Callable:
|
def decorator(function: Callable) -> Callable:
|
||||||
# Process model fields in reverse order for correct option ordering
|
# Process model fields in reverse order for correct option ordering
|
||||||
for name, field in reversed(config_class.model_fields.items()):
|
for name, field in reversed(config_class.model_fields.items()):
|
||||||
if field.annotation == bool:
|
field_type = strip_optional_type(field.annotation)
|
||||||
|
|
||||||
|
if field_type == bool:
|
||||||
field_name = name.replace("_", "-")
|
field_name = name.replace("_", "-")
|
||||||
option_name = f"--{field_name}/--no-{field_name}"
|
option_name = f"--{field_name}/--no-{field_name}"
|
||||||
function = click.option(
|
function = click.option(
|
||||||
@@ -116,11 +140,11 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
|||||||
Build command list from base command and options.
|
Build command list from base command and options.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
base_cmd: Command without options
|
base_cmd: Command without options.
|
||||||
options: Options to parse and append to base command
|
options: Options to parse and append to base command.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of strings giving shell command
|
List of strings giving shell command.
|
||||||
"""
|
"""
|
||||||
cmd = base_cmd.copy()
|
cmd = base_cmd.copy()
|
||||||
|
|
||||||
@@ -146,13 +170,13 @@ def download_file(
|
|||||||
Download a single file and return its processing status.
|
Download a single file and return its processing status.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_info: Tuple of (file_path, remote_sha)
|
file_info: Tuple of (file_path, remote_sha).
|
||||||
raw_base_url: Base URL for raw GitHub content
|
raw_base_url: Base URL for raw GitHub content.
|
||||||
dest_path: Local destination directory
|
dest_path: Local destination directory.
|
||||||
dir_prefix: Directory prefix to filter files
|
dir_prefix: Directory prefix to filter files.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
|
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||||
"""
|
"""
|
||||||
file_path, remote_sha = file_info
|
file_path, remote_sha = file_info
|
||||||
raw_url = f"{raw_base_url}/{file_path}"
|
raw_url = f"{raw_base_url}/{file_path}"
|
||||||
@@ -201,9 +225,10 @@ def fetch_from_github(
|
|||||||
Only downloads files that don't exist locally or have changed.
|
Only downloads files that don't exist locally or have changed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
|
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||||
dest_dir: Local destination directory
|
'deepspeed_configs/').
|
||||||
max_workers: Maximum number of concurrent downloads
|
dest_dir: Local destination directory.
|
||||||
|
max_workers: Maximum number of concurrent downloads.
|
||||||
"""
|
"""
|
||||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||||
|
|||||||
@@ -97,12 +97,10 @@ def load_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_dpo_datasets(
|
def load_preference_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
cli_args: Union[
|
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||||
PreprocessCliArgs, TrainerCliArgs
|
|
||||||
], # pylint: disable=unused-argument
|
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
"""
|
"""
|
||||||
Loads one or more training or evaluation datasets for DPO training, calling
|
Loads one or more training or evaluation datasets for DPO training, calling
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_dpo_datasets
|
from axolotl.common.datasets import load_preference_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
@@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
@@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
@@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
@@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
@@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user