Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath
* tests for evaluate CLI command
* fixes and cleanup
* review comments; slightly DRYing up things

Co-authored-by: Dan Saunders <dan@axolotl.ai>
src/axolotl/cli/evaluate.py (new file, 52 lines)
@@ -0,0 +1,52 @@
"""
CLI to run evaluation on a model
"""
import logging
from pathlib import Path
from typing import Union

import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

from axolotl.cli import (
    check_accelerate_default_config,
    check_user_token,
    load_cfg,
    load_datasets,
    load_rl_datasets,
    print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.evaluate import evaluate

LOG = logging.getLogger("axolotl.cli.evaluate")


def do_evaluate(cfg, cli_args) -> None:
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    if cfg.rl:  # and cfg.rl != "orpo":
        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
    else:
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    # pylint: disable=duplicate-code
    parsed_cfg = load_cfg(config, **kwargs)
    parser = HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    do_evaluate(parsed_cfg, parsed_cli_args)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)
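For orientation, the new entrypoint mirrors axolotl.cli.train: it can be launched as a module, and fire forwards any extra flags into load_cfg as config overrides. A minimal usage sketch (illustrative only, not part of the diff; the config path is a placeholder):

# python -m axolotl.cli.evaluate path/to/config.yml --micro_batch_size=1
# or, equivalently, calling the fire target directly; kwargs become
# config overrides inside load_cfg:
from axolotl.cli.evaluate import do_cli

do_cli(config="path/to/config.yml", micro_batch_size=1)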
src/axolotl/cli/main.py
@@ -12,7 +12,7 @@ from axolotl.cli.utils import (
     build_command,
     fetch_from_github,
 )
-from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
+from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -60,6 +60,31 @@ def train(config: str, accelerate: bool, **kwargs):
         do_cli(config=config, **kwargs)


+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option(
+    "--accelerate/--no-accelerate",
+    default=True,
+    help="Use accelerate launch for multi-GPU evaluation",
+)
+@add_options_from_dataclass(EvaluateCliArgs)
+@add_options_from_config(AxolotlInputConfig)
+def evaluate(config: str, accelerate: bool, **kwargs):
+    """Evaluate a model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    if accelerate:
+        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
+        if config:
+            base_cmd.append(config)
+        cmd = build_command(base_cmd, kwargs)
+        subprocess.run(cmd, check=True)  # nosec B603
+    else:
+        from axolotl.cli.evaluate import do_cli
+
+        do_cli(config=config, **kwargs)
+
+
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
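With accelerate enabled (the default), the command shells out to a fresh process instead of importing the module. For a call like `axolotl evaluate config.yml --micro-batch-size 2`, the list built here comes out roughly as follows; the flag rendering is build_command's job, and the exact shape is asserted by the CLI tests further down:

# Illustrative result of build_command(base_cmd, kwargs):
# ["accelerate", "launch", "-m", "axolotl.cli.evaluate", "config.yml",
#  "--micro-batch-size", "2"]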
src/axolotl/common/cli.py
@@ -15,6 +15,19 @@ configure_logging()
 LOG = logging.getLogger("axolotl.common.cli")


+@dataclass
+class PreprocessCliArgs:
+    """
+    dataclass representing arguments for preprocessing only
+    """
+
+    debug: bool = field(default=False)
+    debug_text_only: bool = field(default=False)
+    debug_num_examples: int = field(default=1)
+    prompter: Optional[str] = field(default=None)
+    download: Optional[bool] = field(default=True)
+
+
 @dataclass
 class TrainerCliArgs:
     """
@@ -31,16 +44,14 @@ class TrainerCliArgs:


 @dataclass
-class PreprocessCliArgs:
+class EvaluateCliArgs:
     """
-    dataclass representing arguments for preprocessing only
+    dataclass representing the various evaluation arguments
     """

     debug: bool = field(default=False)
     debug_text_only: bool = field(default=False)
-    debug_num_examples: int = field(default=1)
-    prompter: Optional[str] = field(default=None)
-    download: Optional[bool] = field(default=True)
+    debug_num_examples: int = field(default=0)


 def load_model_and_tokenizer(
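EvaluateCliArgs keeps only the debug-related fields, with debug_num_examples defaulting to 0 so nothing is dumped during evaluation unless asked for. For reference, a standalone sketch (not from the diff) of how HfArgumentParser turns such a dataclass into CLI flags:

from transformers.hf_argparser import HfArgumentParser

from axolotl.common.cli import EvaluateCliArgs

parser = HfArgumentParser(EvaluateCliArgs)
# Flags --debug, --debug_text_only, and --debug_num_examples are derived
# from the dataclass fields and their defaults.
(cli_args,) = parser.parse_args_into_dataclasses(args=["--debug_num_examples", "2"])
assert cli_args.debug_num_examples == 2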
@@ -50,7 +61,9 @@ def load_model_and_tokenizer(
 ):
     LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
     tokenizer = load_tokenizer(cfg)

     LOG.info("loading model and (optionally) peft_config...")
-    model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+    inference = getattr(cli_args, "inference", False)
+    model, _ = load_model(cfg, tokenizer, inference=inference)

     return model, tokenizer
src/axolotl/evaluate.py (new file, 168 lines)
@@ -0,0 +1,168 @@
"""Module for evaluating models."""

import csv
import os
import sys
from pathlib import Path
from typing import Dict, Optional

import torch
from accelerate.logging import get_logger

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

configure_logging()
LOG = get_logger("axolotl.evaluate")


def evaluate_dataset(
    trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
    """Helper function to evaluate a single dataset safely.

    Args:
        trainer: The trainer instance
        dataset: Dataset to evaluate
        dataset_type: Type of dataset ('train' or 'eval')
        flash_optimum: Whether to use flash optimum

    Returns:
        Dictionary of metrics or None if dataset is None
    """
    if dataset is None:
        return None

    LOG.info(f"Starting {dataset_type} set evaluation...")

    if flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True,
        ):
            metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)
    else:
        metrics = trainer.evaluate(dataset, metric_key_prefix=dataset_type)

    LOG.info(f"{dataset_type.capitalize()} set evaluation completed!")
    LOG.info(f"{dataset_type.capitalize()} Metrics:")
    for key, value in metrics.items():
        LOG.info(f"{key}: {value}")

    return metrics
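metric_key_prefix is what namespaces the returned keys: with the Hugging Face Trainer, trainer.evaluate(dataset, metric_key_prefix="train") yields keys such as train_loss and train_runtime, and likewise eval_* for the eval split. The CSV writer in evaluate() below relies on exactly these prefixes when it strips keys back to bare metric names. Illustrative shape of the return value (values made up):

# {"train_loss": 1.87, "train_runtime": 12.4, "train_samples_per_second": 161.0}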


def evaluate(
    *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
    """
    Evaluate a model on training and validation datasets

    Args:
        cfg: Configuration dictionary
        cli_args: Command line arguments
        dataset_meta: Dataset metadata containing training and evaluation datasets

    Returns:
        Dictionary of evaluation metrics
    """
    # pylint: disable=duplicate-code
    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()

    # Load tokenizer
    LOG.debug(
        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
        main_process_only=True,
    )
    tokenizer = load_tokenizer(cfg)

    # Load processor for multimodal models if needed
    processor = None
    if cfg.is_multimodal:
        processor = load_processor(cfg, tokenizer)

    # Get datasets
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Load model
    LOG.debug("loading model for evaluation...")
    model, _ = load_model(
        cfg, tokenizer, processor=processor, inference=cli_args.inference
    )

    # Set up trainer
    trainer = setup_trainer(
        cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),  # No need for model_ref or peft_config
        tokenizer=tokenizer,
        processor=processor,
        total_num_steps=total_num_steps,
    )

    # Evaluate datasets
    all_metrics = {}
    train_metrics = evaluate_dataset(trainer, train_dataset, "train", cfg.flash_optimum)
    eval_metrics = evaluate_dataset(trainer, eval_dataset, "eval", cfg.flash_optimum)

    if train_metrics:
        all_metrics.update(train_metrics)
    if eval_metrics:
        all_metrics.update(eval_metrics)

    # Save metrics to CSV if output directory is specified and we have metrics
    if cfg.output_dir and (train_metrics or eval_metrics):
        output_dir = Path(cfg.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        metrics_file = output_dir / "eval_summary.csv"
        with metrics_file.open("w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(["metric", "training", "validation"])

            # Get unique metric names (removing prefixes) from available metrics
            train_metric_names = {
                k.replace("train_", ""): k for k in (train_metrics or {})
            }
            eval_metric_names = {
                k.replace("eval_", ""): k for k in (eval_metrics or {})
            }
            all_metric_names = sorted(
                set(train_metric_names.keys()) | set(eval_metric_names.keys())
            )

            for metric_name in all_metric_names:
                train_value = (
                    train_metrics.get(train_metric_names.get(metric_name, ""), "")
                    if train_metrics
                    else ""
                )
                eval_value = (
                    eval_metrics.get(eval_metric_names.get(metric_name, ""), "")
                    if eval_metrics
                    else ""
                )
                writer.writerow([metric_name, train_value, eval_value])

        LOG.info(f"Evaluation results saved to {metrics_file}")

    del model
    del tokenizer

    return all_metrics
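The resulting eval_summary.csv lines each bare metric name up against its training and validation values, leaving a cell empty when one split lacks that metric. Illustrative contents (values are made up):

metric,training,validation
loss,1.8123,1.9542
runtime,42.1,7.3
samples_per_second,96.4,88.2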
src/axolotl/train.py
@@ -24,7 +24,7 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer

 try:
     from optimum.bettertransformer import BetterTransformer
@@ -53,25 +53,22 @@ class TrainDatasetMeta:
 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
-    # enable expandable segments for cuda allocation to improve VRAM usage
-    torch_version = torch.__version__.split(".")
-    torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
-    if torch_major == 2 and torch_minor >= 2:
-        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
-            os.environ[
-                "PYTORCH_CUDA_ALLOC_CONF"
-            ] = "expandable_segments:True,roundup_power2_divisions:16"
+    # Enable expandable segments for cuda allocation to improve VRAM usage
+    set_pytorch_cuda_alloc_conf()

-    # load the tokenizer first
+    # Load tokenizer
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
         main_process_only=True,
     )
     tokenizer = load_tokenizer(cfg)

     # Load processor for multimodal models if needed
     processor = None
     if cfg.is_multimodal:
         processor = load_processor(cfg, tokenizer)

     # Get datasets
     train_dataset = dataset_meta.train_dataset
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
@@ -119,7 +119,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
         eval_dataset = None
         if cfg.dataset_exact_deduplication:
             LOG.info("Deduplication not available for pretrained datasets")

         return train_dataset, eval_dataset, cfg.max_steps, prompters

     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
@@ -134,6 +136,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
             LOG.info(f"Maximum number of steps set at {total_num_steps}")
     else:
         total_num_steps = calculate_total_num_steps(cfg, train_dataset)

     return train_dataset, eval_dataset, total_num_steps, prompters
src/axolotl/utils/trainer.py
@@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"


+def set_pytorch_cuda_alloc_conf():
+    """Set up CUDA allocation config if using PyTorch >= 2.2"""
+    torch_version = torch.__version__.split(".")
+    torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
+    if torch_major == 2 and torch_minor >= 2:
+        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
+            os.environ[
+                "PYTORCH_CUDA_ALLOC_CONF"
+            ] = "expandable_segments:True,roundup_power2_divisions:16"
+
+
 def setup_trainer(
     cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
 ):
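Note that set_pytorch_cuda_alloc_conf only has an effect if it runs before the first CUDA allocation, since PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its caching allocator initializes; this is why both train() and the new evaluate() call it first. A quick standalone check of the gating behavior (illustrative):

import os

from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf

os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
set_pytorch_cuda_alloc_conf()
# Prints the expandable-segments config on torch >= 2.2 and None on older
# torch; a pre-existing value is never overwritten.
print(os.getenv("PYTORCH_CUDA_ALLOC_CONF"))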
tests/cli/test_cli_base.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Base test class for CLI commands."""

from pathlib import Path
from unittest.mock import patch

from axolotl.cli.main import cli


class BaseCliTest:
    """Base class for CLI command tests."""

    def _test_cli_validation(self, cli_runner, command: str):
        """Test CLI validation for a command.

        Args:
            cli_runner: CLI runner fixture
            command: Command to test (train/evaluate)
        """
        # Test missing config file
        result = cli_runner.invoke(cli, [command, "--no-accelerate"])
        assert result.exit_code != 0

        # Test non-existent config file
        result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
        assert result.exit_code != 0
        assert "Error: Invalid value for 'CONFIG'" in result.output

    def _test_basic_execution(
        self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
    ):
        """Test basic execution with accelerate.

        Args:
            cli_runner: CLI runner fixture
            tmp_path: Temporary path fixture
            valid_test_config: Valid config fixture
            command: Command to test (train/evaluate)
        """
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock:
            result = cli_runner.invoke(cli, [command, str(config_path)])

            assert mock.called
            assert mock.call_args.args[0] == [
                "accelerate",
                "launch",
                "-m",
                f"axolotl.cli.{command}",
                str(config_path),
                "--debug-num-examples",
                "0",
            ]
            assert mock.call_args.kwargs == {"check": True}
            assert result.exit_code == 0

    def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
        """Prepare a config file for CLI argument override tests.

        Args:
            tmp_path: Temporary path fixture
            valid_test_config: Valid config fixture
        """
        config_path = tmp_path / "config.yml"
        output_dir = tmp_path / "model-out"

        test_config = valid_test_config.replace(
            "output_dir: model-out", f"output_dir: {output_dir}"
        )
        config_path.write_text(test_config)
        return config_path
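The cli_runner and valid_test_config fixtures referenced throughout these tests are not part of this diff; presumably they live in a shared conftest. A plausible sketch, purely hypothetical, of what they might look like:

# tests/cli/conftest.py (hypothetical sketch, not from this commit)
import pytest
from click.testing import CliRunner


@pytest.fixture
def cli_runner():
    return CliRunner()


@pytest.fixture
def valid_test_config():
    # Minimal config; the real fixture's contents may differ, but the
    # override helper expects the literal line "output_dir: model-out".
    return "\n".join(
        [
            "base_model: HuggingFaceTB/SmolLM2-135M",
            "datasets:",
            "  - path: mhenrichsen/alpaca_2k_test",
            "    type: alpaca",
            "output_dir: model-out",
            "sequence_len: 2048",
            "micro_batch_size: 1",
        ]
    )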
tests/cli/test_cli_evaluate.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""Tests for evaluate CLI command."""

from unittest.mock import patch

from axolotl.cli.main import cli

from .test_cli_base import BaseCliTest


class TestEvaluateCommand(BaseCliTest):
    """Test cases for evaluate command."""

    cli = cli

    def test_evaluate_cli_validation(self, cli_runner):
        """Test CLI validation"""
        self._test_cli_validation(cli_runner, "evaluate")

    def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
        """Test basic successful execution"""
        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")

    def test_evaluate_basic_execution_no_accelerate(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test basic successful execution without accelerate"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
            result = cli_runner.invoke(
                cli,
                [
                    "evaluate",
                    str(config_path),
                    "--no-accelerate",
                ],
                catch_exceptions=False,
            )

            assert result.exit_code == 0
            mock_evaluate.assert_called_once()

    def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
        """Test CLI arguments properly override config values"""
        config_path = self._test_cli_overrides(tmp_path, valid_test_config)

        with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
            result = cli_runner.invoke(
                cli,
                [
                    "evaluate",
                    str(config_path),
                    "--micro-batch-size",
                    "2",
                    "--sequence-len",
                    "128",
                    "--no-accelerate",
                ],
                catch_exceptions=False,
            )

            assert result.exit_code == 0
            mock_evaluate.assert_called_once()
            cfg = mock_evaluate.call_args[0][0]
            assert cfg.micro_batch_size == 2
            assert cfg.sequence_len == 128
tests/cli/test_cli_train.py
@@ -1,98 +1,71 @@
-"""pytest tests for axolotl CLI train command."""
+"""Tests for train CLI command."""

 from unittest.mock import MagicMock, patch

 from axolotl.cli.main import cli

-
-def test_train_cli_validation(cli_runner):
-    """Test CLI validation"""
-    # Test missing config file
-    result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
-    assert result.exit_code != 0
-
-    # Test non-existent config file
-    result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
-    assert result.exit_code != 0
-    assert "Error: Invalid value for 'CONFIG'" in result.output
+from .test_cli_base import BaseCliTest


-def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
-    """Test basic successful execution"""
-    config_path = tmp_path / "config.yml"
-    config_path.write_text(valid_test_config)
+class TestTrainCommand(BaseCliTest):
+    """Test cases for train command."""

-    with patch("subprocess.run") as mock:
-        result = cli_runner.invoke(cli, ["train", str(config_path)])
+    cli = cli

-        assert mock.called
-        assert mock.call_args.args[0] == [
-            "accelerate",
-            "launch",
-            "-m",
-            "axolotl.cli.train",
-            str(config_path),
-            "--debug-num-examples",
-            "0",
-        ]
-        assert mock.call_args.kwargs == {"check": True}
-        assert result.exit_code == 0
+    def test_train_cli_validation(self, cli_runner):
+        """Test CLI validation"""
+        self._test_cli_validation(cli_runner, "train")
+
+    def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
+        """Test basic successful execution"""
+        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")


-def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
-    """Test basic successful execution"""
-    config_path = tmp_path / "config.yml"
-    config_path.write_text(valid_test_config)
+    def test_train_basic_execution_no_accelerate(
+        self, cli_runner, tmp_path, valid_test_config
+    ):
+        """Test basic successful execution without accelerate"""
+        config_path = tmp_path / "config.yml"
+        config_path.write_text(valid_test_config)

-    with patch("axolotl.cli.train.train") as mock_train:
-        mock_train.return_value = (MagicMock(), MagicMock())
+        with patch("axolotl.cli.train.train") as mock_train:
+            mock_train.return_value = (MagicMock(), MagicMock())

-        result = cli_runner.invoke(
-            cli,
-            [
-                "train",
-                str(config_path),
-                "--learning-rate",
-                "1e-4",
-                "--micro-batch-size",
-                "2",
-                "--no-accelerate",
-            ],
-            catch_exceptions=False,
-        )
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--no-accelerate",
+                ],
+                catch_exceptions=False,
+            )

-        assert result.exit_code == 0
-        mock_train.assert_called_once()
+            assert result.exit_code == 0
+            mock_train.assert_called_once()

+    def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
+        """Test CLI arguments properly override config values"""
+        config_path = self._test_cli_overrides(tmp_path, valid_test_config)

-def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
-    """Test CLI arguments properly override config values"""
-    config_path = tmp_path / "config.yml"
-    output_dir = tmp_path / "model-out"
-
-    test_config = valid_test_config.replace(
-        "output_dir: model-out", f"output_dir: {output_dir}"
-    )
-    config_path.write_text(test_config)
-
-    with patch("axolotl.cli.train.train") as mock_train:
-        mock_train.return_value = (MagicMock(), MagicMock())
-
-        result = cli_runner.invoke(
-            cli,
-            [
-                "train",
-                str(config_path),
-                "--learning-rate",
-                "1e-4",
-                "--micro-batch-size",
-                "2",
-                "--no-accelerate",
-            ],
-            catch_exceptions=False,
-        )
+        with patch("axolotl.cli.train.train") as mock_train:
+            mock_train.return_value = (MagicMock(), MagicMock())

-        assert result.exit_code == 0
-        mock_train.assert_called_once()
-        cfg = mock_train.call_args[1]["cfg"]
-        assert cfg["learning_rate"] == 1e-4
-        assert cfg["micro_batch_size"] == 2
+            result = cli_runner.invoke(
+                cli,
+                [
+                    "train",
+                    str(config_path),
+                    "--learning-rate",
+                    "1e-4",
+                    "--micro-batch-size",
+                    "2",
+                    "--no-accelerate",
+                ],
+                catch_exceptions=False,
+            )
+
+            assert result.exit_code == 0
+            mock_train.assert_called_once()
+            cfg = mock_train.call_args[1]["cfg"]
+            assert cfg["learning_rate"] == 1e-4
+            assert cfg["micro_batch_size"] == 2