Compare commits

...

41 Commits

Author SHA1 Message Date
Dan Saunders
66262c3092 moving out all diff attn code to plugin repo 2025-01-24 17:46:11 +00:00
Dan Saunders
016ba124e4 README update 2025-01-23 22:11:35 +00:00
Dan Saunders
7145d52d99 moving diff attn code to separate repo 2025-01-23 21:33:53 +00:00
Dan Saunders
28694219a5 inline comment change 2025-01-14 16:59:43 +00:00
Dan Saunders
fd8ad6fcbf fixing negative component mixing 2025-01-13 19:21:55 +00:00
Dan Saunders
661d71a14b adding diff attn negative component warmup (in progress) 2025-01-10 21:57:31 +00:00
Dan Saunders
6dd47edcb8 fire CLI fixes 2025-01-10 18:24:16 +00:00
Dan Saunders
7aca08ff60 adding guard statements 2025-01-10 16:39:21 +00:00
Dan Saunders
4f804f6d88 adding diff attn callback, adding documentation 2025-01-10 16:28:51 +00:00
Dan Saunders
443327c585 CLI build_command bugfix 2025-01-10 16:28:51 +00:00
Dan Saunders
70c4e6fbe6 updates and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
2a7f139ad2 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
332ce0ae85 fixes and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
e5fa842ff8 update 2025-01-10 16:28:51 +00:00
Dan Saunders
78e0ec0aa5 changes 2025-01-10 16:28:51 +00:00
Dan Saunders
3bc568eb27 adding registration function 2025-01-10 16:28:51 +00:00
Dan Saunders
eb6611d55f progress on modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
4ff3328e66 updated custom modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
a3fd5074a9 fix duplicate-code warnings 2025-01-10 16:28:51 +00:00
Dan Saunders
5b90da0be3 added modeling code; cleanup + refactor 2025-01-10 16:28:51 +00:00
Dan Saunders
fcbfa86373 refactor and fixing test isolation issues 2025-01-10 16:28:51 +00:00
Dan Saunders
0d56582090 adding yaml dumper preserving input config format 2025-01-10 16:28:51 +00:00
Dan Saunders
390cb5742e removing extra pytest xdist args 2025-01-10 16:28:51 +00:00
Dan Saunders
1d935f65c3 moving tests around for flash_attn install 2025-01-10 16:28:51 +00:00
Dan Saunders
66176b3e07 adding split_heads argument for retaining original (Q, K) dimensionanlity 2025-01-10 16:28:51 +00:00
Dan Saunders
505321ac95 isolating problematic test 2025-01-10 16:28:51 +00:00
Dan Saunders
0b382c88da fixes post-rebase 2025-01-10 16:28:51 +00:00
Dan Saunders
ea07a7086e plugin implementation 2025-01-10 16:28:51 +00:00
Dan Saunders
d22e1136bc convert-differential-transformer test coverage 2025-01-10 16:28:51 +00:00
Dan Saunders
63b8e42c6b duplicate code ignore 2025-01-10 16:28:51 +00:00
Dan Saunders
bda1eed59e differential flash attention 2; cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
41ebd93158 moving monkeypatch 2025-01-10 16:28:51 +00:00
Dan Saunders
4c050ce807 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
6665acf63d fix model save / load logic 2025-01-10 16:28:51 +00:00
Dan Saunders
2f9fa4c465 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
849bc94112 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
e484ec778d training fixes, patching, minor cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
df1504ae14 adding CLI command for convert-diff-transformer 2025-01-10 16:28:51 +00:00
Dan Saunders
7be0d7496c Adding script for doing conversion; fixes and updates 2025-01-10 16:28:51 +00:00
Dan Saunders
13cdffa91f initial diff attn layer / model conversion implementation (support for llama arch) 2025-01-10 16:28:51 +00:00
Dan Saunders
7a4b296f60 Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-01-10 16:28:51 +00:00
25 changed files with 296 additions and 57 deletions

3
.gitignore vendored
View File

@@ -186,3 +186,6 @@ out/
# vim
*.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os

View File

@@ -202,7 +202,7 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Dict, Union
import fire
from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
LOG = logging.getLogger("axolotl.cli.evaluate")
def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg, cli_args) -> Dict[str, float]:
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -1,11 +1,13 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
import axolotl
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -77,6 +79,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -254,6 +259,9 @@ def fetch(directory: str, dest: Optional[str]):
fetch_from_github(f"{directory}/", dest)
setup_plugin_commands(cli)
def main():
cli()

View File

@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -22,7 +22,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
@@ -84,6 +92,8 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else:
cmd.extend([f"--{key}", str(value)])

View File

@@ -4,22 +4,26 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
from typing import Optional
from typing import TYPE_CHECKING, Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
if TYPE_CHECKING:
try:
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
except: # noqa: E722 # pylint: disable=bare-except # nosec B110
pass
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
LOG = logging.getLogger(__name__)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
"""dataclass with arguments for preprocessing only"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -30,9 +34,7 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
"""dataclass with various non-training arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -45,9 +47,7 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
"""dataclass with various evaluation arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, "ConvertDiffTransformerCliArgs"],
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

View File

@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""

View File

@@ -9,12 +9,11 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_processor
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -62,8 +61,9 @@ def evaluate_dataset(
return metrics
# pylint: disable=duplicate-code
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
@@ -79,16 +79,11 @@ def evaluate(
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load model
LOG.debug("loading model for evaluation...")
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Load processor for multimodal models if needed
processor = None
@@ -100,12 +95,6 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer
trainer = setup_trainer(
cfg,

View File

@@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -713,19 +713,45 @@ class ModelLoader:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
"differential_eager"
)
if self.cfg.low_cpu_mem_usage:
@@ -816,6 +842,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,

157
src/axolotl/utils/yaml.py Normal file
View File

@@ -0,0 +1,157 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner

View File

@@ -43,14 +43,12 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli
@@ -22,6 +23,7 @@ def test_build_command():
"--batch-size",
"8",
"--debug",
"--nouse-fp16",
]

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -11,14 +12,12 @@ def test_shard_with_accelerate(cli_runner, config_path):
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch