CLI Implementation with Click (#2107)

* Initial CLI implementation with click package

* Adding fetch command for pulling examples and deepspeed configs

* Automating default options for CliArgs classes

* Mimicking existing no config behavior

* bugfix in choose_config

* Updating fetch to sync instead of re-download

* bugfix

* isort fix

* fixing yaml isort order

* pre-commit fixes

* simplifying argument parsing -- pass through kwargs to do_cli

* make accelerate launch default for non-preprocess commands

* fixing arg handling

* testing None placeholder approach

* removing hacky --use-gpu argument to preprocess command

* Adding brief README documentation for CLI

* remove (New)

* Initial CLI pytest tests

* progress on CLI pytest

* adding inference CLI tests; cleanup

* Refactor train CLI tests to remove various mocking

* Major CLI test refator; adding remaining CLI codepath test coverage

* pytest fixes

* remove integration markers

* parallelizing examples, deepspeed config downloads; rename test to match other CLI test naming

* moving cli pytest due to isolation issues; cleanup

* testing fixes; various minor improvements

* fix

* tests fix

* Update tests/cli/conftest.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Dan Saunders
2024-12-05 22:11:48 -05:00
committed by GitHub
parent e399ba533e
commit fc973f4322
25 changed files with 1113 additions and 12 deletions

0
tests/cli/__init__.py Normal file
View File

36
tests/cli/conftest.py Normal file
View File

@@ -0,0 +1,36 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner
VALID_TEST_CONFIG = """
base_model: HuggingFaceTB/SmolLM2-135M
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
sequence_len: 2048
max_steps: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1e-3
special_tokens:
pad_token: <|endoftext|>
"""
@pytest.fixture
def cli_runner():
return CliRunner()
@pytest.fixture
def valid_test_config():
return VALID_TEST_CONFIG
@pytest.fixture
def config_path(tmp_path):
"""Creates a temporary config file"""
path = tmp_path / "config.yml"
path.write_text(VALID_TEST_CONFIG)
return path

View File

@@ -0,0 +1,38 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch
def test_fetch_cli_examples(cli_runner):
"""Test fetch command with examples directory"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
result = cli_runner.invoke(fetch, ["examples"])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("examples/", None)
def test_fetch_cli_deepspeed(cli_runner):
"""Test fetch command with deepspeed_configs directory"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
result = cli_runner.invoke(fetch, ["deepspeed_configs"])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("deepspeed_configs/", None)
def test_fetch_cli_with_dest(cli_runner, tmp_path):
"""Test fetch command with custom destination"""
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
custom_dir = tmp_path / "tmp_examples"
result = cli_runner.invoke(fetch, ["examples", "--dest", str(custom_dir)])
assert result.exit_code == 0
mock_fetch.assert_called_once_with("examples/", str(custom_dir))
def test_fetch_cli_invalid_directory(cli_runner):
"""Test fetch command with invalid directory choice"""
result = cli_runner.invoke(fetch, ["invalid"])
assert result.exit_code != 0

View File

@@ -0,0 +1,30 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli
def test_inference_basic(cli_runner, config_path):
"""Test basic inference"""
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_gradio(cli_runner, config_path):
"""Test basic inference (gradio path)"""
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0

View File

@@ -0,0 +1,47 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli
def test_build_command():
"""Test converting dict of options to CLI arguments"""
base_cmd = ["accelerate", "launch"]
options = {
"learning_rate": 1e-4,
"batch_size": 8,
"debug": True,
"use_fp16": False,
"null_value": None,
}
result = build_command(base_cmd, options)
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
]
def test_invalid_command_options(cli_runner):
"""Test handling of invalid command options"""
result = cli_runner.invoke(
cli,
[
"train",
"config.yml",
"--invalid-option",
"value",
],
)
assert result.exit_code != 0
assert "No such option" in result.output
def test_required_config_argument(cli_runner):
"""Test commands fail properly when config argument is missing"""
result = cli_runner.invoke(cli, ["train"])
assert result.exit_code != 0
assert "Missing argument 'CONFIG'" in result.output

View File

@@ -0,0 +1,56 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli
def test_merge_lora_basic(cli_runner, config_path):
"""Test basic merge_lora command"""
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):
"""Test merge_lora with custom lora and output directories"""
lora_dir = tmp_path / "lora"
output_dir = tmp_path / "output"
lora_dir.mkdir()
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli,
[
"merge-lora",
str(config_path),
"--lora-model-dir",
str(lora_dir),
"--output-dir",
str(output_dir),
],
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["lora_model_dir"] == str(lora_dir)
assert mock_do_cli.call_args.kwargs["output_dir"] == str(output_dir)
def test_merge_lora_nonexistent_config(cli_runner, tmp_path):
"""Test merge_lora with nonexistent config"""
config_path = tmp_path / "nonexistent.yml"
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
assert result.exit_code != 0
def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):
"""Test merge_lora with nonexistent lora directory"""
lora_dir = tmp_path / "nonexistent"
result = cli_runner.invoke(
cli, ["merge-lora", str(config_path), "--lora-model-dir", str(lora_dir)]
)
assert result.exit_code != 0

View File

@@ -0,0 +1,60 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
"""Test merge_sharded_fsdp_weights command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command with save_path option"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--save-path",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -0,0 +1,71 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch
import pytest
from axolotl.cli.main import cli
@pytest.fixture(autouse=True)
def cleanup_last_run_prepared():
yield
if Path("last_run_prepared").exists():
shutil.rmtree("last_run_prepared")
def test_preprocess_config_not_found(cli_runner):
"""Test preprocess fails when config not found"""
result = cli_runner.invoke(cli, ["preprocess", "nonexistent.yml"])
assert result.exit_code != 0
def test_preprocess_basic(cli_runner, config_path):
"""Test basic preprocessing with minimal config"""
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["download"] is True
def test_preprocess_without_download(cli_runner, config_path):
"""Test preprocessing without model download"""
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli, ["preprocess", str(config_path), "--no-download"]
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["download"] is False
def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
"""Test preprocessing with custom dataset path"""
config_path = tmp_path / "config.yml"
custom_path = tmp_path / "custom_prepared"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
result = cli_runner.invoke(
cli,
[
"preprocess",
str(config_path),
"--dataset-prepared-path",
str(custom_path.absolute()),
],
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
custom_path.absolute()
)

View File

@@ -0,0 +1,76 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_shard_with_accelerate(cli_runner, config_path):
"""Test shard command with accelerate"""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_shard_no_accelerate(cli_runner, config_path):
"""Test shard command without accelerate"""
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
assert mock.called
assert result.exit_code == 0
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
"""Test shard command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
catch_exceptions=False,
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_shard_with_save_dir(cli_runner, config_path):
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--save-dir",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -0,0 +1,98 @@
"""pytest tests for axolotl CLI train command."""
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
def test_train_cli_validation(cli_runner):
"""Test CLI validation"""
# Test missing config file
result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["train", str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.train",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
"""Test CLI arguments properly override config values"""
config_path = tmp_path / "config.yml"
output_dir = tmp_path / "model-out"
test_config = valid_test_config.replace(
"output_dir: model-out", f"output_dir: {output_dir}"
)
config_path.write_text(test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_train.assert_called_once()
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2

89
tests/cli/test_utils.py Normal file
View File

@@ -0,0 +1,89 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch
import click
import pytest
import requests
from axolotl.cli.utils import fetch_from_github
# Sample GitHub API response
MOCK_TREE_RESPONSE = {
"tree": [
{"path": "examples/config1.yml", "type": "blob", "sha": "abc123"},
{"path": "examples/config2.yml", "type": "blob", "sha": "def456"},
{"path": "other/file.txt", "type": "blob", "sha": "xyz789"},
]
}
@pytest.fixture
def mock_responses():
"""Mock responses for API and file downloads"""
def mock_get(url, timeout=None): # pylint: disable=unused-argument
response = Mock()
if "api.github.com" in url:
response.text = json.dumps(MOCK_TREE_RESPONSE)
else:
response.content = b"file content"
return response
return mock_get
def test_fetch_from_github_new_files(tmp_path, mock_responses):
"""Test fetching new files"""
with patch("requests.get", mock_responses):
fetch_from_github("examples/", tmp_path)
# Verify files were created
assert (tmp_path / "config1.yml").exists()
assert (tmp_path / "config2.yml").exists()
assert not (tmp_path / "file.txt").exists()
def test_fetch_from_github_unchanged_files(tmp_path, mock_responses):
"""Test handling of unchanged files"""
# Create existing file with matching SHA
existing_file = tmp_path / "config1.yml"
existing_file.write_bytes(b"file content")
with patch("requests.get", mock_responses):
fetch_from_github("examples/", tmp_path)
# File should not be downloaded again
assert existing_file.read_bytes() == b"file content"
def test_fetch_from_github_invalid_prefix(mock_responses):
"""Test error handling for invalid directory prefix"""
with patch("requests.get", mock_responses):
with pytest.raises(click.ClickException):
fetch_from_github("nonexistent/", None)
def test_fetch_from_github_network_error():
"""Test handling of network errors"""
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)
@pytest.fixture
def integration_test_dir(tmp_path):
"""Fixture for integration test directory that cleans up after itself"""
test_dir = tmp_path / "github_downloads"
test_dir.mkdir(parents=True)
yield test_dir
def test_fetch_from_github_real(integration_test_dir):
"""Test actual GitHub API interaction"""
fetch_from_github("examples/", integration_test_dir)
# Verify some known files exist
assert (integration_test_dir / "openllama-3b" / "lora.yml").exists()
assert (integration_test_dir / "openllama-3b" / "qlora.yml").exists()