From 6e72baf287b94f0b557178536e08afe7f46d584c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 8 Jan 2025 19:15:03 +0000 Subject: [PATCH] continued cleanup and documentation --- src/axolotl/cli/art.py | 7 +- src/axolotl/cli/config.py | 69 ++++++++++++++--- src/axolotl/cli/evaluate.py | 21 ++++- src/axolotl/cli/inference.py | 7 +- src/axolotl/cli/main.py | 35 +-------- src/axolotl/cli/merge_lora.py | 6 +- src/axolotl/cli/merge_sharded_fsdp_weights.py | 37 ++++++--- src/axolotl/cli/preprocess.py | 4 +- src/axolotl/cli/shard.py | 33 -------- src/axolotl/cli/train.py | 8 +- src/axolotl/{cli => common}/datasets.py | 74 ++++++++++++++---- src/axolotl/train.py | 14 +--- tests/cli/test_cli_shard.py | 76 ------------------- .../integrations/test_cut_cross_entropy.py | 10 +-- tests/e2e/integrations/test_liger.py | 5 +- tests/e2e/patched/test_4d_multipack_llama.py | 6 +- tests/e2e/patched/test_fa_xentropy.py | 4 +- tests/e2e/patched/test_falcon_samplepack.py | 6 +- tests/e2e/patched/test_fused_llama.py | 4 +- tests/e2e/patched/test_llama_s2_attention.py | 6 +- .../e2e/patched/test_lora_llama_multipack.py | 6 +- tests/e2e/patched/test_mistral_samplepack.py | 6 +- tests/e2e/patched/test_mixtral_samplepack.py | 6 +- tests/e2e/patched/test_model_patches.py | 7 +- tests/e2e/patched/test_phi_multipack.py | 6 +- tests/e2e/patched/test_resume.py | 6 +- tests/e2e/patched/test_unsloth_qlora.py | 8 +- tests/e2e/test_dpo.py | 16 ++-- tests/e2e/test_embeddings_lr.py | 6 +- tests/e2e/test_falcon.py | 8 +- tests/e2e/test_llama.py | 8 +- tests/e2e/test_llama_pretrain.py | 4 +- tests/e2e/test_llama_vision.py | 6 +- tests/e2e/test_lora_llama.py | 4 +- tests/e2e/test_mamba.py | 7 +- tests/e2e/test_mistral.py | 6 +- tests/e2e/test_mixtral.py | 12 +-- tests/e2e/test_optimizers.py | 8 +- tests/e2e/test_packing_loss.py | 4 +- tests/e2e/test_phi.py | 6 +- tests/e2e/test_relora_llama.py | 4 +- tests/e2e/test_reward_model_llama.py | 4 +- 42 files changed, 280 insertions(+), 300 deletions(-) delete mode 100644 src/axolotl/cli/shard.py rename src/axolotl/{cli => common}/datasets.py (55%) delete mode 100644 tests/cli/test_cli_shard.py diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py index 90c6c044e..d6b225b9f 100644 --- a/src/axolotl/cli/art.py +++ b/src/axolotl/cli/art.py @@ -34,7 +34,12 @@ def print_dep_versions(): def print_legacy_axolotl_text_art(suffix=None): - """Prints axolotl ASCII art and dependency versions.""" + """ + Prints axolotl ASCII art and dependency versions. + + Args: + suffix: Text to append to ASCII art text. + """ font = "nancyj" ascii_text = " axolotl" if suffix: diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 44f628e6a..1df664ff8 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -28,7 +28,24 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars LOG = logging.getLogger(__name__) -def check_remote_config(config: Union[str, Path]): +def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: + """ + First, determines if the passed config is a valid HTTPS URL. Then, attempts to query + for it and parse its content, first as JSON, then as YAML (YAML is preferred). + Finally, the parsed content is written to a local file and its path is returned. + + Args: + config: HTTPS URL to a YAML or JSON file. + + Returns: + Either the original `config` if it's not a valid HTTPS URL, or the path to the + downloaded remote config. + + Raises: + ValueError: If the remote configuration is neither valid JSON or YAML. 
+ RuntimeError: If some request-related exception occurs from the file download. + Exception: Catch-all for any other exception. + """ # Check if the config is a valid HTTPS URL to a .yml or .yaml file if not (isinstance(config, str) and config.startswith("https://")): return config # Return the original value if it's not a valid URL @@ -42,9 +59,12 @@ def check_remote_config(config: Union[str, Path]): content = response.content try: - # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML + # Try parsing as JSON first to catch cases where JSON content is mistakenly + # considered YAML. json.loads(content) - # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link + + # Log a warning but do not raise an error; JSON is technically valid YAML. + # This can happen when you forget to point to a raw GitHub link. LOG.warning( f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." ) @@ -75,7 +95,22 @@ def check_remote_config(config: Union[str, Path]): def choose_config(path: Path): - yaml_files = list(path.glob("*.yml")) + """ + Helper method for choosing a `axolotl` config YAML file (considering only files + ending with `.yml` or `.yaml`). If more than one config file exists in the passed + `path`, the user is prompted to choose one. + + Args: + path: Directory in which config file(s) are stored. + + Returns: + Path to either (1) the sole YAML file, or (2) if more than one YAML files exist, + the user-selected YAML file. + + Raises: + ValueError: If no YAML files are found in the given `path`. + """ + yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml")) if not yaml_files: raise ValueError( @@ -104,11 +139,13 @@ def choose_config(path: Path): return chosen_file -def prepare_plugins(cfg): - """ - Prepare the plugins for the configuration +def prepare_plugins(cfg: DictDefault): """ + Registers the plugins for the given configuration. + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + """ if cfg.get("plugins"): plugin_manager = PluginManager.get_instance() for plugin_name in cfg["plugins"]: @@ -116,15 +153,27 @@ def prepare_plugins(cfg): def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): + """ + Loads the `axolotl` configuration stored at `config`, validates it, and performs + various setup. + + Args: + config: Path (local or remote) to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + + Returns: + `DictDefault` mapping configuration keys to values. 
+ """ config = check_remote_config(config) if Path(config).is_dir(): config = choose_config(Path(config)) - # load the config from the yaml file + # Load the config from the yaml file with open(config, encoding="utf-8") as file: cfg: DictDefault = DictDefault(yaml.safe_load(file)) - # if there are any options passed in the cli, if it is something that seems valid from the yaml, - # then overwrite the value + + # If there are any options passed in the cli, if it is something that seems valid + # from the yaml, then overwrite the value cfg_keys = cfg.keys() for k, _ in kwargs.items(): # if not strict, allow writing to cfg even if it's not in the yml already diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 719dc9650..981d0fa64 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -11,14 +11,24 @@ from transformers.hf_argparser import HfArgumentParser from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets, load_rl_datasets from axolotl.evaluate import evaluate +from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def do_evaluate(cfg, cli_args) -> None: +def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: + """ + Evaluates a `transformers` model by first loading the dataset(s) specified in the + `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes + evaluation metrics on the given dataset(s) and writes them to disk. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: CLI arguments. + """ # pylint: disable=duplicate-code print_axolotl_text_art() check_accelerate_default_config() @@ -33,6 +43,13 @@ def do_evaluate(cfg, cli_args) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_evaluate`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser(TrainerCliArgs) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index e950b3528..88cd6e64b 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -52,7 +52,7 @@ def do_inference( Args: cfg: Dictionary mapping `axolotl` config keys to values. - cli_args: Training-specific CLI arguments. + cli_args: Inference-specific CLI arguments. """ model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) prompter = cli_args.prompter @@ -145,7 +145,7 @@ def do_inference_gradio( Args: cfg: Dictionary mapping `axolotl` config keys to values. - cli_args: Training-specific CLI arguments. + cli_args: Inference-specific CLI arguments. """ import gradio as gr @@ -246,8 +246,7 @@ def do_cli( config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs ) -> None: """ - Parses axolotl config, training-specific CLI args, and calls `do_inference` or - `do_inference_gradio` as a subroutine. + Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`. Args: config: Path to `axolotl` config YAML file. 
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 73a291fed..f82d623ce 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -28,7 +28,7 @@ def cli(): @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) -def preprocess(config: str, **kwargs): +def preprocess(config: str, **kwargs) -> None: """ Preprocess datasets before training. @@ -148,39 +148,6 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: do_cli(config=config, gradio=gradio, **kwargs) -@cli.command() -@click.argument("config", type=click.Path(exists=True, path_type=str)) -@click.option( - "--accelerate/--no-accelerate", - default=False, - help="Use accelerate launch for multi-GPU operations", -) -@add_options_from_dataclass(TrainerCliArgs) -@add_options_from_config(AxolotlInputConfig) -def shard(config: str, accelerate: bool, **kwargs) -> None: - """ - Shard model weights into chunks. - - Args: - config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. - kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` - config options. - """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"] - if config: - base_cmd.append(config) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 - else: - from axolotl.cli.shard import do_cli - - do_cli(config=config, **kwargs) - - @cli.command() @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 6e461bf53..2cada4862 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -49,9 +49,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: """ - Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine. - Note that various config values will be overwritten to allow the LoRA merge logic to work - as expected (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.). + Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various + config values will be overwritten to allow the LoRA merge logic to work as expected + (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.). Args: config: Path to `axolotl` config YAML file. diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 0ecf7e70f..9b297b872 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -32,9 +32,7 @@ LOG = logging.getLogger(__name__) class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): - """ - A custom planner to cast tensors to bfloat16 on the fly during loading. 
- """ + """A custom planner to cast tensors to bfloat16 on the fly during loading.""" def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument tensor.copy_(tensor.to(torch.bfloat16)) @@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights( save_path: str, safe_serialization: bool = False, max_shard_size: str = "5GB", -): +) -> Path: """ - Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` + Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will + save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. - Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + Args: + checkpoint_dir: Directory where distributed checkpoint is saved. + save_path: Path to save model to. + safe_serialization: Whether to save in safetensors format. + max_shard_size: Max size of model shards to save. + + Returns: + Path where model is saved. """ state_dict: Dict = {} @@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights( state_dict_split = split_torch_state_dict_into_shards( state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size ) + # Save index if sharded index = None if state_dict_split.is_sharded: @@ -135,6 +142,9 @@ def merge_fsdp_weights( Whether to save the merged weights with safetensors (recommended). remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): Whether to remove the checkpoint directory after merging. + + Raises: + ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist. """ checkpoint_dir_ = Path(checkpoint_dir) from accelerate.state import PartialState @@ -178,18 +188,21 @@ def merge_fsdp_weights( def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + """ + Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code print_axolotl_text_art() - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) parsed_cli_args.merge_lora = True - - parsed_cfg = load_cfg( - config, - **kwargs, - ) + parsed_cfg = load_cfg(config, **kwargs) fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" merge_fsdp_weights( diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index a12be5d27..924695277 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -15,9 +15,9 @@ from transformers import AutoModelForCausalLM from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.common.datasets import load_datasets, load_rl_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.trainer import disable_datasets_caching @@ -77,7 +77,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: """ - Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine. 
+ Parses `axolotl` config, CLI args, and calls `do_preprocess`. Args: config: Path to `axolotl` config YAML file. diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py deleted file mode 100644 index ad680e4b3..000000000 --- a/src/axolotl/cli/shard.py +++ /dev/null @@ -1,33 +0,0 @@ -"""CLI to shard a trained model into 10GiB chunks.""" - -import logging -from pathlib import Path -from typing import Union - -import fire -from dotenv import load_dotenv - -from axolotl.cli.art import print_axolotl_text_art -from axolotl.cli.config import load_cfg -from axolotl.common.cli import load_model_and_tokenizer -from axolotl.utils.dict import DictDefault - -LOG = logging.getLogger(__name__) - - -def shard(*, cfg: DictDefault): - model, _ = load_model_and_tokenizer(cfg=cfg) - safe_serialization = cfg.save_safetensors is True - LOG.debug("Re-saving model w/ sharding") - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - - -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - shard(cfg=parsed_cfg) - - -if __name__ == "__main__": - load_dotenv() - fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 22fa143f0..9bd04919e 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -11,8 +11,8 @@ from transformers.hf_argparser import HfArgumentParser from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg -from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets, load_rl_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train from axolotl.utils.dict import DictDefault @@ -22,8 +22,8 @@ LOG = logging.getLogger(__name__) def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: """ - Trains transformer model by first loading the dataset(s) specified in the `axolotl` - config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin + Trains a `transformers` model by first loading the dataset(s) specified in the + `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin manager's `post_train_unload` once training completes. Args: @@ -50,7 +50,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: """ - Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine. + Parses `axolotl` config, CLI args, and calls `do_train`. Args: config: Path to `axolotl` config YAML file. 
diff --git a/src/axolotl/cli/datasets.py b/src/axolotl/common/datasets.py similarity index 55% rename from src/axolotl/cli/datasets.py rename to src/axolotl/common/datasets.py index 5e44e003b..a4d568c94 100644 --- a/src/axolotl/cli/datasets.py +++ b/src/axolotl/common/datasets.py @@ -1,11 +1,14 @@ """Dataset loading utilities.""" + import logging import math import random -from typing import Union +from dataclasses import dataclass +from typing import Optional, Union + +from datasets import Dataset from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs -from axolotl.train import TrainDatasetMeta from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.dict import DictDefault @@ -15,11 +18,48 @@ from axolotl.utils.tokenization import check_dataset_labels LOG = logging.getLogger(__name__) +@dataclass +class TrainDatasetMeta: + """Dataclass with fields for training and validation datasets and metadata.""" + + train_dataset: Dataset + eval_dataset: Optional[Dataset] = None + total_num_steps: Optional[int] = None + + +def sample_dataset(dataset: Dataset, num_samples: int): + """ + Randomly sample `num_samples` samples from `dataset`. + + Args: + dataset: Dataset. + num_samples: Number of samples to return. + + Returns: + Random sample (with replacement) of examples in `dataset`. + """ + return dataset.select( + [random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec + ) + + def load_datasets( *, cfg: DictDefault, cli_args: Union[PreprocessCliArgs, TrainerCliArgs], ) -> TrainDatasetMeta: + """ + Loads one or more training or evaluation datasets, calling + `axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Command-specific CLI arguments. + + Returns: + Dataclass with fields for training and evaluation datasets and the computed + `total_num_steps`. + """ tokenizer = load_tokenizer(cfg) processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None @@ -36,13 +76,10 @@ def load_datasets( or int(cli_args.debug_num_examples) > 0 ): LOG.info("check_dataset_labels...") + + train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), + train_samples, tokenizer, num_examples=cli_args.debug_num_examples, text_only=cli_args.debug_text_only, @@ -66,6 +103,19 @@ def load_rl_datasets( PreprocessCliArgs, TrainerCliArgs ], # pylint: disable=unused-argument ) -> TrainDatasetMeta: + """ + Loads one or more training or evaluation datasets for RL training, calling + `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug + information. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Command-specific CLI arguments. + + Returns: + Dataclass with fields for training and evaluation datasets and the computed + `total_num_steps`. 
+ """ train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) @@ -75,13 +125,9 @@ def load_rl_datasets( LOG.info("check_dataset_labels...") tokenizer = load_tokenizer(cfg) + train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), + train_samples, tokenizer, num_examples=cli_args.debug_num_examples, text_only=cli_args.debug_text_only, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 20795f4e7..b901c2a97 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -5,20 +5,19 @@ import os import signal import sys import weakref -from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Tuple, Union import torch import transformers.modelcard from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model -from datasets import Dataset from peft import PeftModel from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) @@ -41,15 +40,6 @@ configure_logging() LOG = get_logger(__name__) -@dataclass -class TrainDatasetMeta: - """Dataclass with fields for training and validation datasets and metadata.""" - - train_dataset: Dataset - eval_dataset: Optional[Dataset] = None - total_num_steps: Optional[int] = None - - def train( *, cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py deleted file mode 100644 index 505a2a737..000000000 --- a/tests/cli/test_cli_shard.py +++ /dev/null @@ -1,76 +0,0 @@ -"""pytest tests for axolotl CLI shard command.""" -# pylint: disable=duplicate-code -from unittest.mock import patch - -from axolotl.cli.main import cli - - -def test_shard_with_accelerate(cli_runner, config_path): - """Test shard command with accelerate""" - with patch("subprocess.run") as mock: - result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"]) - - assert mock.called - assert mock.call_args.args[0] == [ - "accelerate", - "launch", - "-m", - "axolotl.cli.shard", - str(config_path), - "--debug-num-examples", - "0", - ] - assert mock.call_args.kwargs == {"check": True} - assert result.exit_code == 0 - - -def test_shard_no_accelerate(cli_runner, config_path): - """Test shard command without accelerate""" - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"]) - - assert mock.called - assert result.exit_code == 0 - - -def test_shard_with_model_dir(cli_runner, config_path, tmp_path): - """Test shard command with model_dir option""" - model_dir = tmp_path / "model" - model_dir.mkdir() - - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "shard", - str(config_path), - "--no-accelerate", - "--model-dir", - str(model_dir), - ], - catch_exceptions=False, - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["model_dir"] 
== str(model_dir) - assert result.exit_code == 0 - - -def test_shard_with_save_dir(cli_runner, config_path): - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "shard", - str(config_path), - "--no-accelerate", - "--save-dir", - "/path/to/save", - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["save_dir"] == "/path/to/save" - assert result.exit_code == 0 diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index c976ef5b7..3205a479c 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration import pytest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils import get_pytorch_version from axolotl.utils.config import normalize_config, prepare_plugins @@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration: major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): with pytest.raises(ImportError): - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) else: - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @pytest.mark.parametrize( @@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration: major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): with pytest.raises(ImportError): - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) else: - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 6dc3108fb..11f5fa9f6 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -6,6 +6,7 @@ from e2e.utils import require_torch_2_4_1 from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.dict import DictDefault @@ -60,7 +61,7 @@ class LigerIntegrationTestCase: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @require_torch_2_4_1 @@ -105,5 +106,5 @@ class LigerIntegrationTestCase: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 3cbe71450..e0b9a453e 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets 
from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class Test4dMultipackLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -109,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index 70de4cd4b..80236d3e8 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -8,8 +8,8 @@ import os import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -80,7 +80,7 @@ class TestFAXentropyLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index b2cc420fa..79ef12996 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestFalconPatched(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -107,5 +107,5 @@ class TestFalconPatched(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 7968274ba..817f7ec61 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -9,8 +9,8 @@ import unittest import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -71,5 +71,5 @@ class TestFusedLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, 
cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 50cd0b8a0..147bca788 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -8,8 +8,8 @@ import unittest import pytest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,7 +69,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -109,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 8cf6c2e57..f321dbddd 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -9,8 +9,8 @@ import unittest import pytest from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -74,7 +74,7 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") @@ -124,5 +124,5 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index 1a8ce8662..7d1d4b94e 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -108,5 +108,5 @@ class 
TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 97400f598..4f960a452 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -64,7 +64,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -102,7 +102,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( "MixtralFlashAttention2" in model.model.layers[0].self_attn.__class__.__name__ diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 170c37fd6..78b01be64 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -6,7 +6,6 @@ import unittest import transformers -from axolotl.common.cli import TrainerCliArgs from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase): } ) normalize_config(cfg) - cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + model, _ = load_model(cfg, tokenizer, inference=False) assert ( "MixtralFlashAttention2" @@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase): } ) normalize_config(cfg) - cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer, inference=cli_args.inference) + load_model(cfg, tokenizer, inference=False) assert ( "torch.jit" diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index 7e8b549f9..1c4104818 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestPhiMultipack(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -118,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, 
cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 3f9551dd4..2523cdafd 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -9,8 +9,8 @@ import subprocess from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -71,7 +71,7 @@ class TestResumeLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) resume_cfg = cfg | DictDefault( { @@ -81,7 +81,7 @@ class TestResumeLlama: normalize_config(resume_cfg) cli_args = TrainerCliArgs() - train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=resume_cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index a083a665f..857bf930e 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -6,8 +6,8 @@ import os import pytest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -75,7 +75,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -125,7 +125,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -180,7 +180,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 6a17e2c44..3a6c00ab0 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -9,8 +9,8 @@ from pathlib import Path import pytest -from axolotl.cli.datasets import load_rl_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_rl_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -112,7 +112,7 
@@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -157,7 +157,7 @@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip("kto_pair no longer supported in trl") @@ -202,7 +202,7 @@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -246,7 +246,7 @@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -293,7 +293,7 @@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip(reason="Fix the implementation") @@ -357,5 +357,5 @@ class TestDPOLlamaLora(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 430dcfb51..60470b898 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -60,7 +60,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -104,7 +104,7 @@ class TestEmbeddingsLrScale(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index e82c1a709..6d31d4caf 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from 
axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,7 +69,7 @@ class TestFalcon(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -122,7 +122,7 @@ class TestFalcon(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -161,5 +161,5 @@ class TestFalcon(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index cb9eb6d44..5a2c1e552 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -7,8 +7,8 @@ import os from e2e.utils import check_model_output_exists -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -60,7 +60,7 @@ class TestLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_fix_untrained_tokens(self, temp_dir): @@ -103,7 +103,7 @@ class TestLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_batch_flattening(self, temp_dir): @@ -142,5 +142,5 @@ class TestLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 8faf28b96..be2c7b50f 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -62,5 +62,5 @@ class TestPretrainLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 5194f446a..fed4b24ac 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets 
import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -66,7 +66,7 @@ class TestLlamaVision(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -111,5 +111,5 @@ class TestLlamaVision(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 9b27ce8d0..ba0696c44 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,5 +63,5 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index a61d3790a..33732602e 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -8,8 +8,8 @@ import unittest import pytest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,5 +63,10 @@ class TestMamba(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) +<<<<<<< HEAD train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) +======= + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "pytorch_model.bin").exists() +>>>>>>> 2a421127 (continued cleanup and documentation) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index 359a490f8..66fedd6a9 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -8,8 +8,8 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -110,5 +110,5 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, 
dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 103305925..c984c3fa0 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -9,8 +9,8 @@ import unittest import torch from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -73,7 +73,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -127,7 +127,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -184,7 +184,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -285,5 +285,5 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index d0805bca3..76da6bb6a 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,7 +63,7 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -107,7 +107,7 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) 
@with_temp_dir @@ -143,5 +143,5 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 748f818a7..71e57509a 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -8,8 +8,8 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -63,7 +63,7 @@ class TestPackedLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index e5d0e86e7..a94af7d30 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class TestPhi(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -114,5 +114,5 @@ class TestPhi(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py index 2bb87c15b..8e02d8a06 100644 --- a/tests/e2e/test_relora_llama.py +++ b/tests/e2e/test_relora_llama.py @@ -7,8 +7,8 @@ import os import unittest from pathlib import Path -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -77,7 +77,7 @@ class TestReLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg) assert ( Path(temp_dir) / "checkpoint-100/relora/model.safetensors" diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py index 9617729bc..469b9cf0d 100644 --- a/tests/e2e/test_reward_model_llama.py +++ b/tests/e2e/test_reward_model_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli.datasets import load_datasets from axolotl.common.cli import TrainerCliArgs +from 
axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,5 +69,5 @@ class TestRewardModelLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg)
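
For reference, a minimal sketch of the calling convention after this patch: `load_datasets` and `TrainDatasetMeta` now live in `axolotl.common.datasets`, and `train` no longer accepts `cli_args`, matching the updated tests above. The YAML path below is a placeholder and is not part of this change.

# Sketch of post-refactor usage; the config path is hypothetical.
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets  # moved from axolotl.cli.datasets
from axolotl.train import train

cfg = load_cfg("examples/config.yaml")  # placeholder config path
cli_args = TrainerCliArgs()

# Returns a TrainDatasetMeta (train_dataset, eval_dataset, total_num_steps),
# now defined in axolotl.common.datasets rather than axolotl.train.
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

# train() now takes only the config and dataset metadata.
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)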