continued cleanup and documentation

This commit is contained in:
Dan Saunders
2025-01-08 19:15:03 +00:00
parent 929ee15cc3
commit 6e72baf287
42 changed files with 280 additions and 300 deletions

View File

@@ -34,7 +34,12 @@ def print_dep_versions():
def print_legacy_axolotl_text_art(suffix=None): def print_legacy_axolotl_text_art(suffix=None):
"""Prints axolotl ASCII art and dependency versions.""" """
Prints axolotl ASCII art and dependency versions.
Args:
suffix: Text to append to ASCII art text.
"""
font = "nancyj" font = "nancyj"
ascii_text = " axolotl" ascii_text = " axolotl"
if suffix: if suffix:

View File

@@ -28,7 +28,24 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def check_remote_config(config: Union[str, Path]): def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
First, determines if the passed config is a valid HTTPS URL. Then, attempts to query
for it and parse its content, first as JSON, then as YAML (YAML is preferred).
Finally, the parsed content is written to a local file and its path is returned.
Args:
config: HTTPS URL to a YAML or JSON file.
Returns:
Either the original `config` if it's not a valid HTTPS URL, or the path to the
downloaded remote config.
Raises:
ValueError: If the remote configuration is neither valid JSON or YAML.
RuntimeError: If some request-related exception occurs from the file download.
Exception: Catch-all for any other exception.
"""
# Check if the config is a valid HTTPS URL to a .yml or .yaml file # Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")): if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL return config # Return the original value if it's not a valid URL
@@ -42,9 +59,12 @@ def check_remote_config(config: Union[str, Path]):
content = response.content content = response.content
try: try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML # Try parsing as JSON first to catch cases where JSON content is mistakenly
# considered YAML.
json.loads(content) json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
# Log a warning but do not raise an error; JSON is technically valid YAML.
# This can happen when you forget to point to a raw GitHub link.
LOG.warning( LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
) )
@@ -75,7 +95,22 @@ def check_remote_config(config: Union[str, Path]):
def choose_config(path: Path): def choose_config(path: Path):
yaml_files = list(path.glob("*.yml")) """
Helper method for choosing a `axolotl` config YAML file (considering only files
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
`path`, the user is prompted to choose one.
Args:
path: Directory in which config file(s) are stored.
Returns:
Path to either (1) the sole YAML file, or (2) if more than one YAML files exist,
the user-selected YAML file.
Raises:
ValueError: If no YAML files are found in the given `path`.
"""
yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml"))
if not yaml_files: if not yaml_files:
raise ValueError( raise ValueError(
@@ -104,11 +139,13 @@ def choose_config(path: Path):
return chosen_file return chosen_file
def prepare_plugins(cfg): def prepare_plugins(cfg: DictDefault):
"""
Prepare the plugins for the configuration
""" """
Registers the plugins for the given configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
if cfg.get("plugins"): if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]: for plugin_name in cfg["plugins"]:
@@ -116,15 +153,27 @@ def prepare_plugins(cfg):
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
Args:
config: Path (local or remote) to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config) config = check_remote_config(config)
if Path(config).is_dir(): if Path(config).is_dir():
config = choose_config(Path(config)) config = choose_config(Path(config))
# load the config from the yaml file # Load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
for k, _ in kwargs.items(): for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already # if not strict, allow writing to cfg even if it's not in the yml already

View File

@@ -11,14 +11,24 @@ from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets, load_rl_datasets
from axolotl.evaluate import evaluate from axolotl.evaluate import evaluate
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def do_evaluate(cfg, cli_args) -> None: def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
evaluation metrics on the given dataset(s) and writes them to disk.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
@@ -33,6 +43,13 @@ def do_evaluate(cfg, cli_args) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_evaluate`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs) parser = HfArgumentParser(TrainerCliArgs)

View File

@@ -52,7 +52,7 @@ def do_inference(
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments. cli_args: Inference-specific CLI arguments.
""" """
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter prompter = cli_args.prompter
@@ -145,7 +145,7 @@ def do_inference_gradio(
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments. cli_args: Inference-specific CLI arguments.
""" """
import gradio as gr import gradio as gr
@@ -246,8 +246,7 @@ def do_cli(
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
) -> None: ) -> None:
""" """
Parses axolotl config, training-specific CLI args, and calls `do_inference` or Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.
`do_inference_gradio` as a subroutine.
Args: Args:
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.

View File

@@ -28,7 +28,7 @@ def cli():
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs) @add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs): def preprocess(config: str, **kwargs) -> None:
""" """
Preprocess datasets before training. Preprocess datasets before training.
@@ -148,39 +148,6 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
do_cli(config=config, gradio=gradio, **kwargs) do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs) -> None:
"""
Shard model weights into chunks.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli
do_cli(config=config, **kwargs)
@cli.command() @cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option( @click.option(

View File

@@ -49,9 +49,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
""" """
Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine. Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
Note that various config values will be overwritten to allow the LoRA merge logic to work config values will be overwritten to allow the LoRA merge logic to work as expected
as expected (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.). (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
Args: Args:
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.

View File

@@ -32,9 +32,7 @@ LOG = logging.getLogger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
""" """A custom planner to cast tensors to bfloat16 on the fly during loading."""
A custom planner to cast tensors to bfloat16 on the fly during loading.
"""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16)) tensor.copy_(tensor.to(torch.bfloat16))
@@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights(
save_path: str, save_path: str,
safe_serialization: bool = False, safe_serialization: bool = False,
max_shard_size: str = "5GB", max_shard_size: str = "5GB",
): ) -> Path:
""" """
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
Path where model is saved.
""" """
state_dict: Dict = {} state_dict: Dict = {}
@@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights(
state_dict_split = split_torch_state_dict_into_shards( state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
) )
# Save index if sharded # Save index if sharded
index = None index = None
if state_dict_split.is_sharded: if state_dict_split.is_sharded:
@@ -135,6 +142,9 @@ def merge_fsdp_weights(
Whether to save the merged weights with safetensors (recommended). Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging. Whether to remove the checkpoint directory after merging.
Raises:
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
""" """
checkpoint_dir_ = Path(checkpoint_dir) checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState from accelerate.state import PartialState
@@ -178,18 +188,21 @@ def merge_fsdp_weights(
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
parsed_cli_args.merge_lora = True parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg = load_cfg(
config,
**kwargs,
)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
merge_fsdp_weights( merge_fsdp_weights(

View File

@@ -15,9 +15,9 @@ from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import PreprocessCliArgs from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_rl_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching from axolotl.utils.trainer import disable_datasets_caching
@@ -77,7 +77,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
""" """
Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine. Parses `axolotl` config, CLI args, and calls `do_preprocess`.
Args: Args:
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.

View File

@@ -1,33 +0,0 @@
"""CLI to shard a trained model into 10GiB chunks."""
import logging
from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def shard(*, cfg: DictDefault):
model, _ = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
shard(cfg=parsed_cfg)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -11,8 +11,8 @@ from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets, load_rl_datasets
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.train import train from axolotl.train import train
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -22,8 +22,8 @@ LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
Trains transformer model by first loading the dataset(s) specified in the `axolotl` Trains a `transformers` model by first loading the dataset(s) specified in the
config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
manager's `post_train_unload` once training completes. manager's `post_train_unload` once training completes.
Args: Args:
@@ -50,7 +50,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
""" """
Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine. Parses `axolotl` config, CLI args, and calls `do_train`.
Args: Args:
config: Path to `axolotl` config YAML file. config: Path to `axolotl` config YAML file.

View File

@@ -1,11 +1,14 @@
"""Dataset loading utilities.""" """Dataset loading utilities."""
import logging import logging
import math import math
import random import random
from typing import Union from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.train import TrainDatasetMeta
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -15,11 +18,48 @@ from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@dataclass
class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int):
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
def load_datasets( def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs], cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
@@ -36,13 +76,10 @@ def load_datasets(
or int(cli_args.debug_num_examples) > 0 or int(cli_args.debug_num_examples) > 0
): ):
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_samples,
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer, tokenizer,
num_examples=cli_args.debug_num_examples, num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only, text_only=cli_args.debug_text_only,
@@ -66,6 +103,19 @@ def load_rl_datasets(
PreprocessCliArgs, TrainerCliArgs PreprocessCliArgs, TrainerCliArgs
], # pylint: disable=unused-argument ], # pylint: disable=unused-argument
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training, calling
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -75,13 +125,9 @@ def load_rl_datasets(
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_samples,
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer, tokenizer,
num_examples=cli_args.debug_num_examples, num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only, text_only=cli_args.debug_text_only,

View File

@@ -5,20 +5,19 @@ import os
import signal import signal
import sys import sys
import weakref import weakref
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Tuple, Union
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel from peft import PeftModel
from pkg_resources import get_distribution # type: ignore from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
@@ -41,15 +40,6 @@ configure_logging()
LOG = get_logger(__name__) LOG = get_logger(__name__)
@dataclass
class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def train( def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:

View File

@@ -1,76 +0,0 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_shard_with_accelerate(cli_runner, config_path):
"""Test shard command with accelerate"""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_shard_no_accelerate(cli_runner, config_path):
"""Test shard command without accelerate"""
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
assert mock.called
assert result.exit_code == 0
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
"""Test shard command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
catch_exceptions=False,
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_shard_with_save_dir(cli_runner, config_path):
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--save-dir",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest import pytest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins
@@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version() major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4): if (major, minor) < (2, 4):
with pytest.raises(ImportError): with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
else: else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version() major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4): if (major, minor) < (2, 4):
with pytest.raises(ImportError): with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
else: else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,6 +6,7 @@ from e2e.utils import require_torch_2_4_1
from axolotl.cli.datasets import load_datasets from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1 @require_torch_2_4_1
@@ -105,5 +106,5 @@ class LigerIntegrationTestCase:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class Test4dMultipackLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -109,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import os
import pytest import pytest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -80,7 +80,7 @@ class TestFAXentropyLlama:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestFalconPatched(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -107,5 +107,5 @@ class TestFalconPatched(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import unittest
import pytest import pytest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -71,5 +71,5 @@ class TestFusedLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
import pytest import pytest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -69,7 +69,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -109,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import unittest
import pytest import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -74,7 +74,7 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@@ -124,5 +124,5 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -108,5 +108,5 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -64,7 +64,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -102,7 +102,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
"MixtralFlashAttention2" "MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__ in model.model.layers[0].self_attn.__class__.__name__

View File

@@ -6,7 +6,6 @@ import unittest
import transformers import transformers
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
@@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase):
} }
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) model, _ = load_model(cfg, tokenizer, inference=False)
assert ( assert (
"MixtralFlashAttention2" "MixtralFlashAttention2"
@@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase):
} }
) )
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=cli_args.inference) load_model(cfg, tokenizer, inference=False)
assert ( assert (
"torch.jit" "torch.jit"

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestPhiMultipack(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -118,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import subprocess
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -71,7 +71,7 @@ class TestResumeLlama:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
resume_cfg = cfg | DictDefault( resume_cfg = cfg | DictDefault(
{ {
@@ -81,7 +81,7 @@ class TestResumeLlama:
normalize_config(resume_cfg) normalize_config(resume_cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=resume_cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")

View File

@@ -6,8 +6,8 @@ import os
import pytest import pytest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -75,7 +75,7 @@ class TestUnslothQLoRA:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
@@ -125,7 +125,7 @@ class TestUnslothQLoRA:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
@@ -180,7 +180,7 @@ class TestUnslothQLoRA:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(

View File

@@ -9,8 +9,8 @@ from pathlib import Path
import pytest import pytest
from axolotl.cli.datasets import load_rl_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_rl_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir @with_temp_dir
@@ -112,7 +112,7 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir @with_temp_dir
@@ -157,7 +157,7 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip("kto_pair no longer supported in trl") @pytest.mark.skip("kto_pair no longer supported in trl")
@@ -202,7 +202,7 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir @with_temp_dir
@@ -246,7 +246,7 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir @with_temp_dir
@@ -293,7 +293,7 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@pytest.mark.skip(reason="Fix the implementation") @pytest.mark.skip(reason="Fix the implementation")
@@ -357,5 +357,5 @@ class TestDPOLlamaLora(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
@@ -104,7 +104,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -69,7 +69,7 @@ class TestFalcon(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -122,7 +122,7 @@ class TestFalcon(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -161,5 +161,5 @@ class TestFalcon(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -7,8 +7,8 @@ import os
from e2e.utils import check_model_output_exists from e2e.utils import check_model_output_exists
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class TestLlama:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir): def test_fix_untrained_tokens(self, temp_dir):
@@ -103,7 +103,7 @@ class TestLlama:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
def test_batch_flattening(self, temp_dir): def test_batch_flattening(self, temp_dir):
@@ -142,5 +142,5 @@ class TestLlama:
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -62,5 +62,5 @@ class TestPretrainLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -66,7 +66,7 @@ class TestLlamaVision(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -111,5 +111,5 @@ class TestLlamaVision(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -63,5 +63,5 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
import pytest import pytest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -63,5 +63,10 @@ class TestMamba(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
<<<<<<< HEAD
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
=======
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
>>>>>>> 2a421127 (continued cleanup and documentation)

View File

@@ -8,8 +8,8 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -110,5 +110,5 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -9,8 +9,8 @@ import unittest
import torch import torch
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -73,7 +73,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -127,7 +127,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -184,7 +184,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -285,5 +285,5 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -63,7 +63,7 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -107,7 +107,7 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -143,5 +143,5 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -8,8 +8,8 @@ import unittest
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -63,7 +63,7 @@ class TestPackedLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestPhi(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
@with_temp_dir @with_temp_dir
@@ -114,5 +114,5 @@ class TestPhi(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -7,8 +7,8 @@ import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -77,7 +77,7 @@ class TestReLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
assert ( assert (
Path(temp_dir) / "checkpoint-100/relora/model.safetensors" Path(temp_dir) / "checkpoint-100/relora/model.safetensors"

View File

@@ -6,8 +6,8 @@ import logging
import os import os
import unittest import unittest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -69,5 +69,5 @@ class TestRewardModelLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)