Files
axolotl/tests/cli/test_cli_base.py
Dan Saunders bb1cae1a20 CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)
* add --launcher option; explicit True/False bool args; small cleanup

* refactor

* add torchrun, accelerate cli args

* add rdzv arg default + tests

* update _quarto

* coderabbit

* fix

* we can't set rdzv_id independently across nodes

* coderabbit

* fix tests
2025-07-30 15:46:56 -04:00

88 lines
2.7 KiB
Python

"""Base test class for CLI commands."""
from pathlib import Path
from unittest.mock import patch
from axolotl.cli.main import cli
class BaseCliTest:
    """Base class for CLI command tests.

    Provides reusable ``_test_*`` helpers that concrete test classes call with
    the command name under test (e.g. ``train`` or ``evaluate``).
    """

    def _test_cli_validation(self, cli_runner, command: str):
        """Test CLI validation for a command.

        Args:
            cli_runner: CLI runner fixture
            command: Command to test (train/evaluate)
        """
        # Test missing config file: the positional CONFIG argument is omitted.
        result = cli_runner.invoke(cli, [command, "--launcher", "python"])
        assert result.exit_code != 0, result.output

        # Test non-existent config file: the CLI must reject the CONFIG value.
        result = cli_runner.invoke(
            cli, [command, "nonexistent.yml", "--launcher", "python"]
        )
        assert result.exit_code != 0, result.output
        assert "Error: Invalid value for 'CONFIG'" in result.output

    def _test_basic_execution(
        self,
        cli_runner,
        tmp_path: Path,
        valid_test_config: str,
        command: str,
        train: bool = True,
    ):
        """Test basic execution with the default (accelerate) launcher.

        Invokes ``command`` without ``--launcher`` while ``subprocess.run`` is
        patched, then checks that exactly the expected ``accelerate launch``
        argv was passed to it with ``check=True``.

        Args:
            cli_runner: CLI runner fixture
            tmp_path: Temporary path fixture
            valid_test_config: Valid config fixture
            command: Command to test (train/evaluate)
            train: Whether to test training (default) or evaluation; training
                additionally expects the ``--shard=False`` flag.
        """
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock:
            result = cli_runner.invoke(cli, [command, str(config_path)])

        # Check the exit code first so that a CLI failure reports its output
        # instead of an opaque "mock was not called" assertion. The mock keeps
        # its recorded calls after the patch context exits.
        assert result.exit_code == 0, result.output
        mock.assert_called()

        expected = [
            "accelerate",
            "launch",
            "-m",
            f"axolotl.cli.{command}",
            str(config_path),
            "--debug=False",
            "--debug-text-only=False",
            "--debug-num-examples=0",
        ]
        if train:
            expected.append("--shard=False")

        # call_args reflects the most recent subprocess.run invocation.
        assert mock.call_args.args[0] == expected
        assert mock.call_args.kwargs == {"check": True}

    def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str) -> Path:
        """Write a test config whose ``output_dir`` points inside ``tmp_path``.

        Shared setup for CLI override tests. The ``output_dir: model-out`` line
        of ``valid_test_config`` is rewritten to the absolute path
        ``tmp_path / "model-out"``, and the result is written to
        ``tmp_path / "config.yml"``.

        Args:
            tmp_path: Temporary path fixture
            valid_test_config: Contents of the valid config fixture

        Returns:
            Path to the written config file.
        """
        config_path = tmp_path / "config.yml"
        output_dir = tmp_path / "model-out"
        # NOTE(review): str.replace is a silent no-op if the fixture does not
        # contain the exact text "output_dir: model-out"; in that case the
        # config is written unchanged and keeps its original output_dir.
        test_config = valid_test_config.replace(
            "output_dir: model-out", f"output_dir: {output_dir}"
        )
        config_path.write_text(test_config)
        return config_path