continued cleanup and documentation
This commit is contained in:
@@ -34,7 +34,12 @@ def print_dep_versions():
|
|||||||
|
|
||||||
|
|
||||||
def print_legacy_axolotl_text_art(suffix=None):
|
def print_legacy_axolotl_text_art(suffix=None):
|
||||||
"""Prints axolotl ASCII art and dependency versions."""
|
"""
|
||||||
|
Prints axolotl ASCII art and dependency versions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
suffix: Text to append to ASCII art text.
|
||||||
|
"""
|
||||||
font = "nancyj"
|
font = "nancyj"
|
||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
|
|||||||
@@ -28,7 +28,24 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]):
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
|
"""
|
||||||
|
First, determines if the passed config is a valid HTTPS URL. Then, attempts to query
|
||||||
|
for it and parse its content, first as JSON, then as YAML (YAML is preferred).
|
||||||
|
Finally, the parsed content is written to a local file and its path is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: HTTPS URL to a YAML or JSON file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either the original `config` if it's not a valid HTTPS URL, or the path to the
|
||||||
|
downloaded remote config.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the remote configuration is neither valid JSON or YAML.
|
||||||
|
RuntimeError: If some request-related exception occurs from the file download.
|
||||||
|
Exception: Catch-all for any other exception.
|
||||||
|
"""
|
||||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||||
if not (isinstance(config, str) and config.startswith("https://")):
|
if not (isinstance(config, str) and config.startswith("https://")):
|
||||||
return config # Return the original value if it's not a valid URL
|
return config # Return the original value if it's not a valid URL
|
||||||
@@ -42,9 +59,12 @@ def check_remote_config(config: Union[str, Path]):
|
|||||||
|
|
||||||
content = response.content
|
content = response.content
|
||||||
try:
|
try:
|
||||||
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
|
# Try parsing as JSON first to catch cases where JSON content is mistakenly
|
||||||
|
# considered YAML.
|
||||||
json.loads(content)
|
json.loads(content)
|
||||||
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
|
|
||||||
|
# Log a warning but do not raise an error; JSON is technically valid YAML.
|
||||||
|
# This can happen when you forget to point to a raw GitHub link.
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
||||||
)
|
)
|
||||||
@@ -75,7 +95,22 @@ def check_remote_config(config: Union[str, Path]):
|
|||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
yaml_files = list(path.glob("*.yml"))
|
"""
|
||||||
|
Helper method for choosing a `axolotl` config YAML file (considering only files
|
||||||
|
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
|
||||||
|
`path`, the user is prompted to choose one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Directory in which config file(s) are stored.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to either (1) the sole YAML file, or (2) if more than one YAML files exist,
|
||||||
|
the user-selected YAML file.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no YAML files are found in the given `path`.
|
||||||
|
"""
|
||||||
|
yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml"))
|
||||||
|
|
||||||
if not yaml_files:
|
if not yaml_files:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -104,11 +139,13 @@ def choose_config(path: Path):
|
|||||||
return chosen_file
|
return chosen_file
|
||||||
|
|
||||||
|
|
||||||
def prepare_plugins(cfg):
|
def prepare_plugins(cfg: DictDefault):
|
||||||
"""
|
|
||||||
Prepare the plugins for the configuration
|
|
||||||
"""
|
"""
|
||||||
|
Registers the plugins for the given configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
"""
|
||||||
if cfg.get("plugins"):
|
if cfg.get("plugins"):
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
for plugin_name in cfg["plugins"]:
|
for plugin_name in cfg["plugins"]:
|
||||||
@@ -116,15 +153,27 @@ def prepare_plugins(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||||
|
"""
|
||||||
|
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||||
|
various setup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Path (local or remote) to `axolotl` config YAML file.
|
||||||
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`DictDefault` mapping configuration keys to values.
|
||||||
|
"""
|
||||||
config = check_remote_config(config)
|
config = check_remote_config(config)
|
||||||
if Path(config).is_dir():
|
if Path(config).is_dir():
|
||||||
config = choose_config(Path(config))
|
config = choose_config(Path(config))
|
||||||
|
|
||||||
# load the config from the yaml file
|
# Load the config from the yaml file
|
||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
|
||||||
# then overwrite the value
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
|
# from the yaml, then overwrite the value
|
||||||
cfg_keys = cfg.keys()
|
cfg_keys = cfg.keys()
|
||||||
for k, _ in kwargs.items():
|
for k, _ in kwargs.items():
|
||||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||||
|
|||||||
@@ -11,14 +11,24 @@ from transformers.hf_argparser import HfArgumentParser
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets, load_rl_datasets
|
||||||
from axolotl.evaluate import evaluate
|
from axolotl.evaluate import evaluate
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg, cli_args) -> None:
|
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||||
|
"""
|
||||||
|
Evaluates a `transformers` model by first loading the dataset(s) specified in the
|
||||||
|
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
|
||||||
|
evaluation metrics on the given dataset(s) and writes them to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
cli_args: CLI arguments.
|
||||||
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
@@ -33,6 +43,13 @@ def do_evaluate(cfg, cli_args) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Parses `axolotl` config, CLI args, and calls `do_evaluate`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Path to `axolotl` config YAML file.
|
||||||
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parser = HfArgumentParser(TrainerCliArgs)
|
parser = HfArgumentParser(TrainerCliArgs)
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def do_inference(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Inference-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
@@ -145,7 +145,7 @@ def do_inference_gradio(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Inference-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@@ -246,8 +246,7 @@ def do_cli(
|
|||||||
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
|
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parses axolotl config, training-specific CLI args, and calls `do_inference` or
|
Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.
|
||||||
`do_inference_gradio` as a subroutine.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Path to `axolotl` config YAML file.
|
config: Path to `axolotl` config YAML file.
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def cli():
|
|||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@add_options_from_dataclass(PreprocessCliArgs)
|
@add_options_from_dataclass(PreprocessCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
def preprocess(config: str, **kwargs):
|
def preprocess(config: str, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Preprocess datasets before training.
|
Preprocess datasets before training.
|
||||||
|
|
||||||
@@ -148,39 +148,6 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
|||||||
do_cli(config=config, gradio=gradio, **kwargs)
|
do_cli(config=config, gradio=gradio, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
|
||||||
@click.option(
|
|
||||||
"--accelerate/--no-accelerate",
|
|
||||||
default=False,
|
|
||||||
help="Use accelerate launch for multi-GPU operations",
|
|
||||||
)
|
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
|
||||||
def shard(config: str, accelerate: bool, **kwargs) -> None:
|
|
||||||
"""
|
|
||||||
Shard model weights into chunks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Path to `axolotl` config YAML file.
|
|
||||||
accelerate: Whether to use `accelerate` launcher.
|
|
||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
|
||||||
config options.
|
|
||||||
"""
|
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
if accelerate:
|
|
||||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
|
|
||||||
if config:
|
|
||||||
base_cmd.append(config)
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
|
||||||
else:
|
|
||||||
from axolotl.cli.shard import do_cli
|
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|||||||
@@ -49,9 +49,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine.
|
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
|
||||||
Note that various config values will be overwritten to allow the LoRA merge logic to work
|
config values will be overwritten to allow the LoRA merge logic to work as expected
|
||||||
as expected (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
|
(`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Path to `axolotl` config YAML file.
|
config: Path to `axolotl` config YAML file.
|
||||||
|
|||||||
@@ -32,9 +32,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
||||||
"""
|
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
|
||||||
A custom planner to cast tensors to bfloat16 on the fly during loading.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
|
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
|
||||||
tensor.copy_(tensor.to(torch.bfloat16))
|
tensor.copy_(tensor.to(torch.bfloat16))
|
||||||
@@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights(
|
|||||||
save_path: str,
|
save_path: str,
|
||||||
safe_serialization: bool = False,
|
safe_serialization: bool = False,
|
||||||
max_shard_size: str = "5GB",
|
max_shard_size: str = "5GB",
|
||||||
):
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
|
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
|
||||||
|
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
|
||||||
|
|
||||||
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
|
Args:
|
||||||
|
checkpoint_dir: Directory where distributed checkpoint is saved.
|
||||||
|
save_path: Path to save model to.
|
||||||
|
safe_serialization: Whether to save in safetensors format.
|
||||||
|
max_shard_size: Max size of model shards to save.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path where model is saved.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state_dict: Dict = {}
|
state_dict: Dict = {}
|
||||||
@@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights(
|
|||||||
state_dict_split = split_torch_state_dict_into_shards(
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save index if sharded
|
# Save index if sharded
|
||||||
index = None
|
index = None
|
||||||
if state_dict_split.is_sharded:
|
if state_dict_split.is_sharded:
|
||||||
@@ -135,6 +142,9 @@ def merge_fsdp_weights(
|
|||||||
Whether to save the merged weights with safetensors (recommended).
|
Whether to save the merged weights with safetensors (recommended).
|
||||||
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
|
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to remove the checkpoint directory after merging.
|
Whether to remove the checkpoint directory after merging.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
|
||||||
"""
|
"""
|
||||||
checkpoint_dir_ = Path(checkpoint_dir)
|
checkpoint_dir_ = Path(checkpoint_dir)
|
||||||
from accelerate.state import PartialState
|
from accelerate.state import PartialState
|
||||||
@@ -178,18 +188,21 @@ def merge_fsdp_weights(
|
|||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
"""
|
||||||
|
Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Path to `axolotl` config YAML file.
|
||||||
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser(TrainerCliArgs)
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
parsed_cli_args.merge_lora = True
|
parsed_cli_args.merge_lora = True
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parsed_cfg = load_cfg(
|
|
||||||
config,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
merge_fsdp_weights(
|
merge_fsdp_weights(
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from transformers import AutoModelForCausalLM
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
from axolotl.common.datasets import load_datasets, load_rl_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.trainer import disable_datasets_caching
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
|
|
||||||
@@ -77,7 +77,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine.
|
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Path to `axolotl` config YAML file.
|
config: Path to `axolotl` config YAML file.
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
"""CLI to shard a trained model into 10GiB chunks."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import fire
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
|
||||||
from axolotl.common.cli import load_model_and_tokenizer
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def shard(*, cfg: DictDefault):
|
|
||||||
model, _ = load_model_and_tokenizer(cfg=cfg)
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
LOG.debug("Re-saving model w/ sharding")
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
|
||||||
shard(cfg=parsed_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
load_dotenv()
|
|
||||||
fire.Fire(do_cli)
|
|
||||||
@@ -11,8 +11,8 @@ from transformers.hf_argparser import HfArgumentParser
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets, load_rl_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -22,8 +22,8 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||||
"""
|
"""
|
||||||
Trains transformer model by first loading the dataset(s) specified in the `axolotl`
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
manager's `post_train_unload` once training completes.
|
manager's `post_train_unload` once training completes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -50,7 +50,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Path to `axolotl` config YAML file.
|
config: Path to `axolotl` config YAML file.
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
"""Dataset loading utilities."""
|
"""Dataset loading utilities."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import Union
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
|
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.train import TrainDatasetMeta
|
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -15,11 +18,48 @@ from axolotl.utils.tokenization import check_dataset_labels
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainDatasetMeta:
|
||||||
|
"""Dataclass with fields for training and validation datasets and metadata."""
|
||||||
|
|
||||||
|
train_dataset: Dataset
|
||||||
|
eval_dataset: Optional[Dataset] = None
|
||||||
|
total_num_steps: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def sample_dataset(dataset: Dataset, num_samples: int):
|
||||||
|
"""
|
||||||
|
Randomly sample `num_samples` samples from `dataset`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset.
|
||||||
|
num_samples: Number of samples to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Random sample (with replacement) of examples in `dataset`.
|
||||||
|
"""
|
||||||
|
return dataset.select(
|
||||||
|
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
|
"""
|
||||||
|
Loads one or more training or evaluation datasets, calling
|
||||||
|
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
cli_args: Command-specific CLI arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataclass with fields for training and evaluation datasets and the computed
|
||||||
|
`total_num_steps`.
|
||||||
|
"""
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||||
|
|
||||||
@@ -36,13 +76,10 @@ def load_datasets(
|
|||||||
or int(cli_args.debug_num_examples) > 0
|
or int(cli_args.debug_num_examples) > 0
|
||||||
):
|
):
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
|
|
||||||
|
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
train_dataset.select(
|
train_samples,
|
||||||
[
|
|
||||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
|
||||||
for _ in range(cli_args.debug_num_examples)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
tokenizer,
|
tokenizer,
|
||||||
num_examples=cli_args.debug_num_examples,
|
num_examples=cli_args.debug_num_examples,
|
||||||
text_only=cli_args.debug_text_only,
|
text_only=cli_args.debug_text_only,
|
||||||
@@ -66,6 +103,19 @@ def load_rl_datasets(
|
|||||||
PreprocessCliArgs, TrainerCliArgs
|
PreprocessCliArgs, TrainerCliArgs
|
||||||
], # pylint: disable=unused-argument
|
], # pylint: disable=unused-argument
|
||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
|
"""
|
||||||
|
Loads one or more training or evaluation datasets for RL training, calling
|
||||||
|
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
|
||||||
|
information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
cli_args: Command-specific CLI arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataclass with fields for training and evaluation datasets and the computed
|
||||||
|
`total_num_steps`.
|
||||||
|
"""
|
||||||
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
|
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
@@ -75,13 +125,9 @@ def load_rl_datasets(
|
|||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
train_dataset.select(
|
train_samples,
|
||||||
[
|
|
||||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
|
||||||
for _ in range(cli_args.debug_num_examples)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
tokenizer,
|
tokenizer,
|
||||||
num_examples=cli_args.debug_num_examples,
|
num_examples=cli_args.debug_num_examples,
|
||||||
text_only=cli_args.debug_text_only,
|
text_only=cli_args.debug_text_only,
|
||||||
@@ -5,20 +5,19 @@ import os
|
|||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
from datasets import Dataset
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from pkg_resources import get_distribution # type: ignore
|
from pkg_resources import get_distribution # type: ignore
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
@@ -41,15 +40,6 @@ configure_logging()
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainDatasetMeta:
|
|
||||||
"""Dataclass with fields for training and validation datasets and metadata."""
|
|
||||||
|
|
||||||
train_dataset: Dataset
|
|
||||||
eval_dataset: Optional[Dataset] = None
|
|
||||||
total_num_steps: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
"""pytest tests for axolotl CLI shard command."""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
|
||||||
|
|
||||||
|
|
||||||
def test_shard_with_accelerate(cli_runner, config_path):
|
|
||||||
"""Test shard command with accelerate"""
|
|
||||||
with patch("subprocess.run") as mock:
|
|
||||||
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
|
|
||||||
|
|
||||||
assert mock.called
|
|
||||||
assert mock.call_args.args[0] == [
|
|
||||||
"accelerate",
|
|
||||||
"launch",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.shard",
|
|
||||||
str(config_path),
|
|
||||||
"--debug-num-examples",
|
|
||||||
"0",
|
|
||||||
]
|
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_shard_no_accelerate(cli_runner, config_path):
|
|
||||||
"""Test shard command without accelerate"""
|
|
||||||
with patch("axolotl.cli.shard.do_cli") as mock:
|
|
||||||
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
|
|
||||||
|
|
||||||
assert mock.called
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
|
|
||||||
"""Test shard command with model_dir option"""
|
|
||||||
model_dir = tmp_path / "model"
|
|
||||||
model_dir.mkdir()
|
|
||||||
|
|
||||||
with patch("axolotl.cli.shard.do_cli") as mock:
|
|
||||||
result = cli_runner.invoke(
|
|
||||||
cli,
|
|
||||||
[
|
|
||||||
"shard",
|
|
||||||
str(config_path),
|
|
||||||
"--no-accelerate",
|
|
||||||
"--model-dir",
|
|
||||||
str(model_dir),
|
|
||||||
],
|
|
||||||
catch_exceptions=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock.called
|
|
||||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
|
||||||
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_shard_with_save_dir(cli_runner, config_path):
|
|
||||||
with patch("axolotl.cli.shard.do_cli") as mock:
|
|
||||||
result = cli_runner.invoke(
|
|
||||||
cli,
|
|
||||||
[
|
|
||||||
"shard",
|
|
||||||
str(config_path),
|
|
||||||
"--no-accelerate",
|
|
||||||
"--save-dir",
|
|
||||||
"/path/to/save",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock.called
|
|
||||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
|
||||||
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
|
|
||||||
assert result.exit_code == 0
|
|
||||||
@@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
@@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration:
|
|||||||
major, minor, _ = get_pytorch_version()
|
major, minor, _ = get_pytorch_version()
|
||||||
if (major, minor) < (2, 4):
|
if (major, minor) < (2, 4):
|
||||||
with pytest.raises(ImportError):
|
with pytest.raises(ImportError):
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
else:
|
else:
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
|
|||||||
major, minor, _ = get_pytorch_version()
|
major, minor, _ = get_pytorch_version()
|
||||||
if (major, minor) < (2, 4):
|
if (major, minor) < (2, 4):
|
||||||
with pytest.raises(ImportError):
|
with pytest.raises(ImportError):
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
else:
|
else:
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from e2e.utils import require_torch_2_4_1
|
|||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
from axolotl.cli.datasets import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@require_torch_2_4_1
|
@require_torch_2_4_1
|
||||||
@@ -105,5 +106,5 @@ class LigerIntegrationTestCase:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -65,7 +65,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -109,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import os
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -80,7 +80,7 @@ class TestFAXentropyLlama:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -67,7 +67,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -107,5 +107,5 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import unittest
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -71,5 +71,5 @@ class TestFusedLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import unittest
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -69,7 +69,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -109,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import unittest
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -74,7 +74,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
||||||
@@ -124,5 +124,5 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -108,5 +108,5 @@ class TestMistral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -64,7 +64,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -102,7 +102,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
"MixtralFlashAttention2"
|
"MixtralFlashAttention2"
|
||||||
in model.model.layers[0].self_attn.__class__.__name__
|
in model.model.layers[0].self_attn.__class__.__name__
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import unittest
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, _ = load_model(cfg, tokenizer, inference=False)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"MixtralFlashAttention2"
|
"MixtralFlashAttention2"
|
||||||
@@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
load_model(cfg, tokenizer, inference=cli_args.inference)
|
load_model(cfg, tokenizer, inference=False)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"torch.jit"
|
"torch.jit"
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -67,7 +67,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -118,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import subprocess
|
|||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -71,7 +71,7 @@ class TestResumeLlama:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
resume_cfg = cfg | DictDefault(
|
resume_cfg = cfg | DictDefault(
|
||||||
{
|
{
|
||||||
@@ -81,7 +81,7 @@ class TestResumeLlama:
|
|||||||
normalize_config(resume_cfg)
|
normalize_config(resume_cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
|
|
||||||
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -75,7 +75,7 @@ class TestUnslothQLoRA:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
@@ -125,7 +125,7 @@ class TestUnslothQLoRA:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
@@ -180,7 +180,7 @@ class TestUnslothQLoRA:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_rl_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_rl_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -67,7 +67,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -112,7 +112,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -157,7 +157,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||||
@@ -202,7 +202,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -246,7 +246,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -293,7 +293,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@pytest.mark.skip(reason="Fix the implementation")
|
@pytest.mark.skip(reason="Fix the implementation")
|
||||||
@@ -357,5 +357,5 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -60,7 +60,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
@@ -104,7 +104,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -69,7 +69,7 @@ class TestFalcon(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -122,7 +122,7 @@ class TestFalcon(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -161,5 +161,5 @@ class TestFalcon(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import os
|
|||||||
|
|
||||||
from e2e.utils import check_model_output_exists
|
from e2e.utils import check_model_output_exists
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -60,7 +60,7 @@ class TestLlama:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
def test_fix_untrained_tokens(self, temp_dir):
|
def test_fix_untrained_tokens(self, temp_dir):
|
||||||
@@ -103,7 +103,7 @@ class TestLlama:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
def test_batch_flattening(self, temp_dir):
|
def test_batch_flattening(self, temp_dir):
|
||||||
@@ -142,5 +142,5 @@ class TestLlama:
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -62,5 +62,5 @@ class TestPretrainLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -66,7 +66,7 @@ class TestLlamaVision(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -111,5 +111,5 @@ class TestLlamaVision(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -63,5 +63,5 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import unittest
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -63,5 +63,10 @@ class TestMamba(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
=======
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||||
|
>>>>>>> 2a421127 (continued cleanup and documentation)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import unittest
|
|||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -110,5 +110,5 @@ class TestMistral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -73,7 +73,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -127,7 +127,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -184,7 +184,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -285,5 +285,5 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -63,7 +63,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -107,7 +107,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -143,5 +143,5 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import unittest
|
|||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -63,7 +63,7 @@ class TestPackedLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -65,7 +65,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@@ -114,5 +114,5 @@ class TestPhi(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -77,7 +77,7 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||||
assert (
|
assert (
|
||||||
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
|
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.datasets import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -69,5 +69,5 @@ class TestRewardModelLoraLlama(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user