* Use exec instead of subprocess so that Ctrl+C behaves better in the CLI * Rename the variable to use_exec * Simplify it to a bool * Flush the std* streams * Patch subprocess with a mock in tests * Fix tests * More test fixes
94 lines
2.9 KiB
Python
94 lines
2.9 KiB
Python
"""Base test class for CLI commands."""
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
from axolotl.cli.main import cli
|
|
|
|
|
|
class BaseCliTest:
    """Base class for CLI command tests.

    The ``_test_*`` helpers below are not collected by pytest on their own
    (their names do not start with ``test``); presumably concrete subclasses
    call them from their own test methods, passing through fixtures.
    """

    def _test_cli_validation(self, cli_runner, command: str) -> None:
        """Check that ``command`` rejects a missing or non-existent config.

        Args:
            cli_runner: CLI runner fixture used to invoke ``cli``.
            command: Command to test (train/evaluate).
        """
        # No config argument at all: the command must fail.
        result = cli_runner.invoke(cli, [command, "--launcher", "python"])
        assert result.exit_code != 0, result.output

        # Config path that does not exist: must fail with an invalid-CONFIG error.
        result = cli_runner.invoke(
            cli, [command, "nonexistent.yml", "--launcher", "python"]
        )
        assert result.exit_code != 0, result.output
        assert "Error: Invalid value for 'CONFIG'" in result.output

    def _test_basic_execution(
        self,
        cli_runner,
        tmp_path: Path,
        valid_test_config: str,
        command: str,
        train: bool = True,
    ) -> None:
        """Check that ``command`` hands off to ``accelerate launch``.

        The launcher call is patched so nothing is actually launched. The
        ``train`` command is expected to call ``os.execvpe``. Any other
        command is expected to call ``subprocess.run(..., check=True)``.

        Args:
            cli_runner: CLI runner fixture.
            tmp_path: Temporary path fixture.
            valid_test_config: Valid config fixture (YAML text).
            command: Command to test (train/evaluate).
            train: Whether ``--shard=False`` is expected at the end of the
                launched arguments. Defaults to ``True`` (training); pass
                ``False`` when testing evaluation.
        """
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        # Only `train` replaces the current process; other commands spawn a child.
        use_exec = command == "train"
        mock_fn = "os.execvpe" if use_exec else "subprocess.run"

        with patch(mock_fn) as mock:
            result = cli_runner.invoke(cli, [command, str(config_path)])

            # Check the exit code first so a crashing CLI reports its output
            # instead of an uninformative "mock not called" failure.
            assert result.exit_code == 0, result.output
            assert mock.called, f"{mock_fn} was not called"

            expected = [
                "accelerate",
                "launch",
                "-m",
                f"axolotl.cli.{command}",
                str(config_path),
                "--debug=False",
                "--debug-text-only=False",
                "--debug-num-examples=0",
            ]
            if train:
                expected.append("--shard=False")

            if use_exec:
                # os.execvpe(file, args, env): program to run, then its full argv.
                assert mock.call_args.args[0] == "accelerate"
                assert mock.call_args.args[1] == expected
            else:
                assert mock.call_args.args[0] == expected
                assert mock.call_args.kwargs == {"check": True}

    def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str) -> Path:
        """Write a config whose ``output_dir`` points inside ``tmp_path``.

        Prepares the config file for CLI-override tests. The literal line
        ``output_dir: model-out`` in ``valid_test_config`` is replaced with a
        path under ``tmp_path``, presumably so that any output lands in the
        temporary directory. If that exact text is absent, the config is
        written unchanged.

        Args:
            tmp_path: Temporary path fixture.
            valid_test_config: Valid config fixture (YAML text).

        Returns:
            Path to the written ``config.yml`` inside ``tmp_path``.
        """
        config_path = tmp_path / "config.yml"
        output_dir = tmp_path / "model-out"

        test_config = valid_test_config.replace(
            "output_dir: model-out", f"output_dir: {output_dir}"
        )
        config_path.write_text(test_config)
        return config_path
|