Basic evaluate CLI command / codepath (#2188)

* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
This commit is contained in:
Dan Saunders
2024-12-16 15:46:31 -05:00
committed by GitHub
parent 33090486d7
commit f865464ae5
10 changed files with 480 additions and 98 deletions

View File

@@ -0,0 +1,73 @@
"""Base test class for CLI commands."""
from pathlib import Path
from unittest.mock import patch
from axolotl.cli.main import cli
class BaseCliTest:
"""Base class for CLI command tests."""
def _test_cli_validation(self, cli_runner, command: str):
"""Test CLI validation for a command.
Args:
cli_runner: CLI runner fixture
command: Command to test (train/evaluate)
"""
# Test missing config file
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def _test_basic_execution(
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
):
"""Test basic execution with accelerate.
Args:
cli_runner: CLI runner fixture
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
"""Test CLI argument overrides.
Args:
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
"""
config_path = tmp_path / "config.yml"
output_dir = tmp_path / "model-out"
test_config = valid_test_config.replace(
"output_dir: model-out", f"output_dir: {output_dir}"
)
config_path.write_text(test_config)
return config_path

View File

@@ -0,0 +1,67 @@
"""Tests for evaluate CLI command."""
from unittest.mock import patch
from axolotl.cli.main import cli
from .test_cli_base import BaseCliTest
class TestEvaluateCommand(BaseCliTest):
"""Test cases for evaluate command."""
cli = cli
def test_evaluate_cli_validation(self, cli_runner):
"""Test CLI validation"""
self._test_cli_validation(cli_runner, "evaluate")
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
def test_evaluate_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test basic successful execution without accelerate"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_evaluate.assert_called_once()
def test_evaluate_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
"""Test CLI arguments properly override config values"""
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--micro-batch-size",
"2",
"--sequence-len",
"128",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_evaluate.assert_called_once()
cfg = mock_evaluate.call_args[0][0]
assert cfg.micro_batch_size == 2
assert cfg.sequence_len == 128

View File

@@ -1,98 +1,71 @@
"""pytest tests for axolotl CLI train command."""
"""Tests for train CLI command."""
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
def test_train_cli_validation(cli_runner):
"""Test CLI validation"""
# Test missing config file
result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
from .test_cli_base import BaseCliTest
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
class TestTrainCommand(BaseCliTest):
"""Test cases for train command."""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["train", str(config_path)])
cli = cli
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.train",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_train_cli_validation(self, cli_runner):
"""Test CLI validation"""
self._test_cli_validation(cli_runner, "train")
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
def test_train_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test basic successful execution without accelerate"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
assert result.exit_code == 0
mock_train.assert_called_once()
def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
"""Test CLI arguments properly override config values"""
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
"""Test CLI arguments properly override config values"""
config_path = tmp_path / "config.yml"
output_dir = tmp_path / "model-out"
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
test_config = valid_test_config.replace(
"output_dir: model-out", f"output_dir: {output_dir}"
)
config_path.write_text(test_config)
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2
assert result.exit_code == 0
mock_train.assert_called_once()
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2