Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath * tests for evaluate CLI command * fixes and cleanup * review comments; slightly DRYing up things --------- Co-authored-by: Dan Saunders <dan@axolotl.ai>
This commit is contained in:
52
src/axolotl/cli/evaluate.py
Normal file
52
src/axolotl/cli/evaluate.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""
|
||||||
|
CLI to run training on a model
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
|
from axolotl.cli import (
|
||||||
|
check_accelerate_default_config,
|
||||||
|
check_user_token,
|
||||||
|
load_cfg,
|
||||||
|
load_datasets,
|
||||||
|
load_rl_datasets,
|
||||||
|
print_axolotl_text_art,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.evaluate import evaluate
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.evaluate")
|
||||||
|
|
||||||
|
|
||||||
|
def do_evaluate(cfg, cli_args) -> None:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
check_accelerate_default_config()
|
||||||
|
check_user_token()
|
||||||
|
|
||||||
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
else:
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
parser = HfArgumentParser(TrainerCliArgs)
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
do_evaluate(parsed_cfg, parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
|
fire.Fire(do_cli)
|
||||||
@@ -12,7 +12,7 @@ from axolotl.cli.utils import (
|
|||||||
build_command,
|
build_command,
|
||||||
fetch_from_github,
|
fetch_from_github,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
|
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +60,31 @@ def train(config: str, accelerate: bool, **kwargs):
|
|||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
|
@click.option(
|
||||||
|
"--accelerate/--no-accelerate",
|
||||||
|
default=True,
|
||||||
|
help="Use accelerate launch for multi-GPU training",
|
||||||
|
)
|
||||||
|
@add_options_from_dataclass(EvaluateCliArgs)
|
||||||
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
def evaluate(config: str, accelerate: bool, **kwargs):
|
||||||
|
"""Evaluate a model."""
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
|
if accelerate:
|
||||||
|
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||||
|
if config:
|
||||||
|
base_cmd.append(config)
|
||||||
|
cmd = build_command(base_cmd, kwargs)
|
||||||
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
|
else:
|
||||||
|
from axolotl.cli.evaluate import do_cli
|
||||||
|
|
||||||
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@click.option(
|
@click.option(
|
||||||
|
|||||||
@@ -15,6 +15,19 @@ configure_logging()
|
|||||||
LOG = logging.getLogger("axolotl.common.cli")
|
LOG = logging.getLogger("axolotl.common.cli")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessCliArgs:
|
||||||
|
"""
|
||||||
|
dataclass representing arguments for preprocessing only
|
||||||
|
"""
|
||||||
|
|
||||||
|
debug: bool = field(default=False)
|
||||||
|
debug_text_only: bool = field(default=False)
|
||||||
|
debug_num_examples: int = field(default=1)
|
||||||
|
prompter: Optional[str] = field(default=None)
|
||||||
|
download: Optional[bool] = field(default=True)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainerCliArgs:
|
class TrainerCliArgs:
|
||||||
"""
|
"""
|
||||||
@@ -31,16 +44,14 @@ class TrainerCliArgs:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PreprocessCliArgs:
|
class EvaluateCliArgs:
|
||||||
"""
|
"""
|
||||||
dataclass representing arguments for preprocessing only
|
dataclass representing the various evaluation arguments
|
||||||
"""
|
"""
|
||||||
|
|
||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
debug_text_only: bool = field(default=False)
|
debug_text_only: bool = field(default=False)
|
||||||
debug_num_examples: int = field(default=1)
|
debug_num_examples: int = field(default=0)
|
||||||
prompter: Optional[str] = field(default=None)
|
|
||||||
download: Optional[bool] = field(default=True)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
@@ -50,7 +61,9 @@ def load_model_and_tokenizer(
|
|||||||
):
|
):
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
LOG.info("loading model and (optionally) peft_config...")
|
||||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
inference = getattr(cli_args, "inference", False)
|
||||||
|
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|||||||
168
src/axolotl/evaluate.py
Normal file
168
src/axolotl/evaluate.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Module for evaluating models."""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
from axolotl.train import TrainDatasetMeta
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
|
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
|
||||||
|
|
||||||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
src_dir = os.path.join(project_root, "src")
|
||||||
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = get_logger("axolotl.evaluate")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_dataset(
|
||||||
|
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
|
) -> Optional[Dict[str, float]]:
|
||||||
|
"""Helper function to evaluate a single dataset safely.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer: The trainer instance
|
||||||
|
dataset: Dataset to evaluate
|
||||||
|
dataset_type: Type of dataset ('train' or 'eval')
|
||||||
|
flash_optimum: Whether to use flash optimum
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of metrics or None if dataset is None
|
||||||
|
"""
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
LOG.info(f"Starting {dataset_type} set evaluation...")
|
||||||
|
|
||||||
|
if flash_optimum:
|
||||||
|
with torch.backends.cuda.sdp_kernel(
|
||||||
|
enable_flash=True,
|
||||||
|
enable_math=True,
|
||||||
|
enable_mem_efficient=True,
|
||||||
|
):
|
||||||
|
metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)
|
||||||
|
else:
|
||||||
|
metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)
|
||||||
|
|
||||||
|
LOG.info(f"{dataset_type.capitalize()} set evaluation completed!")
|
||||||
|
LOG.info(f"{dataset_type.capitalize()} Metrics:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
LOG.info(f"{key}: {value}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Evaluate a model on training and validation datasets
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Configuration dictionary
|
||||||
|
cli_args: Command line arguments
|
||||||
|
dataset_meta: Dataset metadata containing training and evaluation datasets
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- The model (either PeftModel or PreTrainedModel)
|
||||||
|
- The tokenizer
|
||||||
|
- Dictionary of evaluation metrics
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
LOG.debug(
|
||||||
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# Load processor for multimodal models if needed
|
||||||
|
processor = None
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
# Get datasets
|
||||||
|
train_dataset = dataset_meta.train_dataset
|
||||||
|
eval_dataset = dataset_meta.eval_dataset
|
||||||
|
total_num_steps = dataset_meta.total_num_steps
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
LOG.debug("loading model for evaluation...")
|
||||||
|
model, _ = load_model(
|
||||||
|
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up trainer
|
||||||
|
trainer = setup_trainer(
|
||||||
|
cfg,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
model=(model, None, None), # No need for model_ref or peft_config
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
|
total_num_steps=total_num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate datasets
|
||||||
|
all_metrics = {}
|
||||||
|
train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum)
|
||||||
|
eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum)
|
||||||
|
|
||||||
|
if train_metrics:
|
||||||
|
all_metrics.update(train_metrics)
|
||||||
|
if eval_metrics:
|
||||||
|
all_metrics.update(eval_metrics)
|
||||||
|
|
||||||
|
# Save metrics to CSV if output directory is specified and we have metrics
|
||||||
|
if cfg.output_dir and (train_metrics or eval_metrics):
|
||||||
|
output_dir = Path(cfg.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
metrics_file = output_dir / "eval_summary.csv"
|
||||||
|
with metrics_file.open("w", newline="", encoding="utf-8") as file:
|
||||||
|
writer = csv.writer(file)
|
||||||
|
writer.writerow(["metric", "training", "validation"])
|
||||||
|
|
||||||
|
# Get unique metric names (removing prefixes) from available metrics
|
||||||
|
train_metric_names = {
|
||||||
|
k.replace("train_", ""): k for k in (train_metrics or {})
|
||||||
|
}
|
||||||
|
eval_metric_names = {
|
||||||
|
k.replace("eval_", ""): k for k in (eval_metrics or {})
|
||||||
|
}
|
||||||
|
all_metric_names = sorted(
|
||||||
|
set(train_metric_names.keys()) | set(eval_metric_names.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
for metric_name in all_metric_names:
|
||||||
|
train_value = (
|
||||||
|
train_metrics.get(train_metric_names.get(metric_name, ""), "")
|
||||||
|
if train_metrics
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
eval_value = (
|
||||||
|
eval_metrics.get(eval_metric_names.get(metric_name, ""), "")
|
||||||
|
if eval_metrics
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
writer.writerow([metric_name, train_value, eval_value])
|
||||||
|
|
||||||
|
LOG.info(f"Evaluation results saved to {metrics_file}")
|
||||||
|
|
||||||
|
del model
|
||||||
|
del tokenizer
|
||||||
|
|
||||||
|
return all_metrics
|
||||||
@@ -24,7 +24,7 @@ from axolotl.logging_config import configure_logging
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -53,25 +53,22 @@ class TrainDatasetMeta:
|
|||||||
def train(
|
def train(
|
||||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||||
# enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
torch_version = torch.__version__.split(".")
|
set_pytorch_cuda_alloc_conf()
|
||||||
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
|
||||||
if torch_major == 2 and torch_minor >= 2:
|
|
||||||
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
|
||||||
os.environ[
|
|
||||||
"PYTORCH_CUDA_ALLOC_CONF"
|
|
||||||
] = "expandable_segments:True,roundup_power2_divisions:16"
|
|
||||||
|
|
||||||
# load the tokenizer first
|
# Load tokenizer
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# Load processor for multimodal models if needed
|
||||||
processor = None
|
processor = None
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
processor = load_processor(cfg, tokenizer)
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
# Get datasets
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
eval_dataset = dataset_meta.eval_dataset
|
eval_dataset = dataset_meta.eval_dataset
|
||||||
total_num_steps = dataset_meta.total_num_steps
|
total_num_steps = dataset_meta.total_num_steps
|
||||||
|
|||||||
@@ -119,7 +119,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication:
|
||||||
LOG.info("Deduplication not available for pretrained datasets")
|
LOG.info("Deduplication not available for pretrained datasets")
|
||||||
|
|
||||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||||
|
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
if total_eval_steps == 0:
|
if total_eval_steps == 0:
|
||||||
@@ -134,6 +136,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||||
|
|
||||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg):
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
def set_pytorch_cuda_alloc_conf():
|
||||||
|
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
|
||||||
|
torch_version = torch.__version__.split(".")
|
||||||
|
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
|
if torch_major == 2 and torch_minor >= 2:
|
||||||
|
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
|
||||||
|
os.environ[
|
||||||
|
"PYTORCH_CUDA_ALLOC_CONF"
|
||||||
|
] = "expandable_segments:True,roundup_power2_divisions:16"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
|
|||||||
73
tests/cli/test_cli_base.py
Normal file
73
tests/cli/test_cli_base.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Base test class for CLI commands."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCliTest:
|
||||||
|
"""Base class for CLI command tests."""
|
||||||
|
|
||||||
|
def _test_cli_validation(self, cli_runner, command: str):
|
||||||
|
"""Test CLI validation for a command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cli_runner: CLI runner fixture
|
||||||
|
command: Command to test (train/evaluate)
|
||||||
|
"""
|
||||||
|
# Test missing config file
|
||||||
|
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
|
||||||
|
# Test non-existent config file
|
||||||
|
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||||
|
|
||||||
|
def _test_basic_execution(
|
||||||
|
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
|
||||||
|
):
|
||||||
|
"""Test basic execution with accelerate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cli_runner: CLI runner fixture
|
||||||
|
tmp_path: Temporary path fixture
|
||||||
|
valid_test_config: Valid config fixture
|
||||||
|
command: Command to test (train/evaluate)
|
||||||
|
"""
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
|
with patch("subprocess.run") as mock:
|
||||||
|
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||||
|
|
||||||
|
assert mock.called
|
||||||
|
assert mock.call_args.args[0] == [
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
"-m",
|
||||||
|
f"axolotl.cli.{command}",
|
||||||
|
str(config_path),
|
||||||
|
"--debug-num-examples",
|
||||||
|
"0",
|
||||||
|
]
|
||||||
|
assert mock.call_args.kwargs == {"check": True}
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||||
|
"""Test CLI argument overrides.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: Temporary path fixture
|
||||||
|
valid_test_config: Valid config fixture
|
||||||
|
command: Command to test (train/evaluate)
|
||||||
|
"""
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
output_dir = tmp_path / "model-out"
|
||||||
|
|
||||||
|
test_config = valid_test_config.replace(
|
||||||
|
"output_dir: model-out", f"output_dir: {output_dir}"
|
||||||
|
)
|
||||||
|
config_path.write_text(test_config)
|
||||||
|
return config_path
|
||||||
67
tests/cli/test_cli_evaluate.py
Normal file
67
tests/cli/test_cli_evaluate.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Tests for evaluate CLI command."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
from .test_cli_base import BaseCliTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvaluateCommand(BaseCliTest):
|
||||||
|
"""Test cases for evaluate command."""
|
||||||
|
|
||||||
|
cli = cli
|
||||||
|
|
||||||
|
def test_evaluate_cli_validation(self, cli_runner):
|
||||||
|
"""Test CLI validation"""
|
||||||
|
self._test_cli_validation(cli_runner, "evaluate")
|
||||||
|
|
||||||
|
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
|
||||||
|
"""Test basic successful execution"""
|
||||||
|
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
|
||||||
|
|
||||||
|
def test_evaluate_basic_execution_no_accelerate(
|
||||||
|
self, cli_runner, tmp_path, valid_test_config
|
||||||
|
):
|
||||||
|
"""Test basic successful execution without accelerate"""
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
|
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
|
||||||
|
result = cli_runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"evaluate",
|
||||||
|
str(config_path),
|
||||||
|
"--no-accelerate",
|
||||||
|
],
|
||||||
|
catch_exceptions=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
mock_evaluate.assert_called_once()
|
||||||
|
|
||||||
|
def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
|
||||||
|
"""Test CLI arguments properly override config values"""
|
||||||
|
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
||||||
|
|
||||||
|
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
|
||||||
|
result = cli_runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"evaluate",
|
||||||
|
str(config_path),
|
||||||
|
"--micro-batch-size",
|
||||||
|
"2",
|
||||||
|
"--sequence-len",
|
||||||
|
"128",
|
||||||
|
"--no-accelerate",
|
||||||
|
],
|
||||||
|
catch_exceptions=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
mock_evaluate.assert_called_once()
|
||||||
|
cfg = mock_evaluate.call_args[0][0]
|
||||||
|
assert cfg.micro_batch_size == 2
|
||||||
|
assert cfg.sequence_len == 128
|
||||||
@@ -1,98 +1,71 @@
|
|||||||
"""pytest tests for axolotl CLI train command."""
|
"""Tests for train CLI command."""
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
from .test_cli_base import BaseCliTest
|
||||||
def test_train_cli_validation(cli_runner):
|
|
||||||
"""Test CLI validation"""
|
|
||||||
# Test missing config file
|
|
||||||
result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
|
|
||||||
assert result.exit_code != 0
|
|
||||||
|
|
||||||
# Test non-existent config file
|
|
||||||
result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
|
|
||||||
assert result.exit_code != 0
|
|
||||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
|
class TestTrainCommand(BaseCliTest):
|
||||||
"""Test basic successful execution"""
|
"""Test cases for train command."""
|
||||||
config_path = tmp_path / "config.yml"
|
|
||||||
config_path.write_text(valid_test_config)
|
|
||||||
|
|
||||||
with patch("subprocess.run") as mock:
|
cli = cli
|
||||||
result = cli_runner.invoke(cli, ["train", str(config_path)])
|
|
||||||
|
|
||||||
assert mock.called
|
def test_train_cli_validation(self, cli_runner):
|
||||||
assert mock.call_args.args[0] == [
|
"""Test CLI validation"""
|
||||||
"accelerate",
|
self._test_cli_validation(cli_runner, "train")
|
||||||
"launch",
|
|
||||||
"-m",
|
|
||||||
"axolotl.cli.train",
|
|
||||||
str(config_path),
|
|
||||||
"--debug-num-examples",
|
|
||||||
"0",
|
|
||||||
]
|
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
|
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
|
||||||
|
"""Test basic successful execution"""
|
||||||
|
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
|
||||||
|
|
||||||
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
|
def test_train_basic_execution_no_accelerate(
|
||||||
"""Test basic successful execution"""
|
self, cli_runner, tmp_path, valid_test_config
|
||||||
config_path = tmp_path / "config.yml"
|
):
|
||||||
config_path.write_text(valid_test_config)
|
"""Test basic successful execution without accelerate"""
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("axolotl.cli.train.train") as mock_train:
|
with patch("axolotl.cli.train.train") as mock_train:
|
||||||
mock_train.return_value = (MagicMock(), MagicMock())
|
mock_train.return_value = (MagicMock(), MagicMock())
|
||||||
|
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"train",
|
"train",
|
||||||
str(config_path),
|
str(config_path),
|
||||||
"--learning-rate",
|
"--no-accelerate",
|
||||||
"1e-4",
|
],
|
||||||
"--micro-batch-size",
|
catch_exceptions=False,
|
||||||
"2",
|
)
|
||||||
"--no-accelerate",
|
|
||||||
],
|
|
||||||
catch_exceptions=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
mock_train.assert_called_once()
|
mock_train.assert_called_once()
|
||||||
|
|
||||||
|
def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
|
||||||
|
"""Test CLI arguments properly override config values"""
|
||||||
|
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
||||||
|
|
||||||
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
|
with patch("axolotl.cli.train.train") as mock_train:
|
||||||
"""Test CLI arguments properly override config values"""
|
mock_train.return_value = (MagicMock(), MagicMock())
|
||||||
config_path = tmp_path / "config.yml"
|
|
||||||
output_dir = tmp_path / "model-out"
|
|
||||||
|
|
||||||
test_config = valid_test_config.replace(
|
result = cli_runner.invoke(
|
||||||
"output_dir: model-out", f"output_dir: {output_dir}"
|
cli,
|
||||||
)
|
[
|
||||||
config_path.write_text(test_config)
|
"train",
|
||||||
|
str(config_path),
|
||||||
|
"--learning-rate",
|
||||||
|
"1e-4",
|
||||||
|
"--micro-batch-size",
|
||||||
|
"2",
|
||||||
|
"--no-accelerate",
|
||||||
|
],
|
||||||
|
catch_exceptions=False,
|
||||||
|
)
|
||||||
|
|
||||||
with patch("axolotl.cli.train.train") as mock_train:
|
assert result.exit_code == 0
|
||||||
mock_train.return_value = (MagicMock(), MagicMock())
|
mock_train.assert_called_once()
|
||||||
|
cfg = mock_train.call_args[1]["cfg"]
|
||||||
result = cli_runner.invoke(
|
assert cfg["learning_rate"] == 1e-4
|
||||||
cli,
|
assert cfg["micro_batch_size"] == 2
|
||||||
[
|
|
||||||
"train",
|
|
||||||
str(config_path),
|
|
||||||
"--learning-rate",
|
|
||||||
"1e-4",
|
|
||||||
"--micro-batch-size",
|
|
||||||
"2",
|
|
||||||
"--no-accelerate",
|
|
||||||
],
|
|
||||||
catch_exceptions=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.exit_code == 0
|
|
||||||
mock_train.assert_called_once()
|
|
||||||
cfg = mock_train.call_args[1]["cfg"]
|
|
||||||
assert cfg["learning_rate"] == 1e-4
|
|
||||||
assert cfg["micro_batch_size"] == 2
|
|
||||||
|
|||||||
Reference in New Issue
Block a user