CLI Implementation with Click (#2107)
* Initial CLI implementation with click package * Adding fetch command for pulling examples and deepspeed configs * Automating default options for CliArgs classes * Mimicking existing no config behavior * bugfix in choose_config * Updating fetch to sync instead of re-download * bugfix * isort fix * fixing yaml isort order * pre-commit fixes * simplifying argument parsing -- pass through kwargs to do_cli * make accelerate launch default for non-preprocess commands * fixing arg handling * testing None placeholder approach * removing hacky --use-gpu argument to preprocess command * Adding brief README documentation for CLI * remove (New) * Initial CLI pytest tests * progress on CLI pytest * adding inference CLI tests; cleanup * Refactor train CLI tests to remove various mocking * Major CLI test refator; adding remaining CLI codepath test coverage * pytest fixes * remove integration markers * parallelizing examples, deepspeed config downloads; rename test to match other CLI test naming * moving cli pytest due to isolation issues; cleanup * testing fixes; various minor improvements * fix * tests fix * Update tests/cli/conftest.py Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Dan Saunders <dan@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
0
tests/cli/__init__.py
Normal file
0
tests/cli/__init__.py
Normal file
36
tests/cli/conftest.py
Normal file
36
tests/cli/conftest.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Shared pytest fixtures for cli module."""
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
VALID_TEST_CONFIG = """
|
||||
base_model: HuggingFaceTB/SmolLM2-135M
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
sequence_len: 2048
|
||||
max_steps: 1
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 1e-3
|
||||
special_tokens:
|
||||
pad_token: <|endoftext|>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_test_config():
|
||||
return VALID_TEST_CONFIG
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_path(tmp_path):
|
||||
"""Creates a temporary config file"""
|
||||
path = tmp_path / "config.yml"
|
||||
path.write_text(VALID_TEST_CONFIG)
|
||||
|
||||
return path
|
||||
38
tests/cli/test_cli_fetch.py
Normal file
38
tests/cli/test_cli_fetch.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""pytest tests for axolotl CLI fetch command."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import fetch
|
||||
|
||||
|
||||
def test_fetch_cli_examples(cli_runner):
|
||||
"""Test fetch command with examples directory"""
|
||||
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
|
||||
result = cli_runner.invoke(fetch, ["examples"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_fetch.assert_called_once_with("examples/", None)
|
||||
|
||||
|
||||
def test_fetch_cli_deepspeed(cli_runner):
|
||||
"""Test fetch command with deepspeed_configs directory"""
|
||||
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
|
||||
result = cli_runner.invoke(fetch, ["deepspeed_configs"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_fetch.assert_called_once_with("deepspeed_configs/", None)
|
||||
|
||||
|
||||
def test_fetch_cli_with_dest(cli_runner, tmp_path):
|
||||
"""Test fetch command with custom destination"""
|
||||
with patch("axolotl.cli.main.fetch_from_github") as mock_fetch:
|
||||
custom_dir = tmp_path / "tmp_examples"
|
||||
result = cli_runner.invoke(fetch, ["examples", "--dest", str(custom_dir)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_fetch.assert_called_once_with("examples/", str(custom_dir))
|
||||
|
||||
|
||||
def test_fetch_cli_invalid_directory(cli_runner):
|
||||
"""Test fetch command with invalid directory choice"""
|
||||
result = cli_runner.invoke(fetch, ["invalid"])
|
||||
assert result.exit_code != 0
|
||||
30
tests/cli/test_cli_inference.py
Normal file
30
tests/cli/test_cli_inference.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""pytest tests for axolotl CLI inference command."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_inference_basic(cli_runner, config_path):
|
||||
"""Test basic inference"""
|
||||
with patch("axolotl.cli.inference.do_inference") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
["inference", str(config_path), "--no-accelerate"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_inference_gradio(cli_runner, config_path):
|
||||
"""Test basic inference (gradio path)"""
|
||||
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
["inference", str(config_path), "--no-accelerate", "--gradio"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert result.exit_code == 0
|
||||
47
tests/cli/test_cli_interface.py
Normal file
47
tests/cli/test_cli_interface.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""General pytest tests for axolotl.cli.main interface."""
|
||||
from axolotl.cli.main import build_command, cli
|
||||
|
||||
|
||||
def test_build_command():
|
||||
"""Test converting dict of options to CLI arguments"""
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
options = {
|
||||
"learning_rate": 1e-4,
|
||||
"batch_size": 8,
|
||||
"debug": True,
|
||||
"use_fp16": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
result = build_command(base_cmd, options)
|
||||
assert result == [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--learning-rate",
|
||||
"0.0001",
|
||||
"--batch-size",
|
||||
"8",
|
||||
"--debug",
|
||||
]
|
||||
|
||||
|
||||
def test_invalid_command_options(cli_runner):
|
||||
"""Test handling of invalid command options"""
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
"config.yml",
|
||||
"--invalid-option",
|
||||
"value",
|
||||
],
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "No such option" in result.output
|
||||
|
||||
|
||||
def test_required_config_argument(cli_runner):
|
||||
"""Test commands fail properly when config argument is missing"""
|
||||
result = cli_runner.invoke(cli, ["train"])
|
||||
assert result.exit_code != 0
|
||||
assert "Missing argument 'CONFIG'" in result.output
|
||||
56
tests/cli/test_cli_merge_lora.py
Normal file
56
tests/cli/test_cli_merge_lora.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""pytest tests for axolotl CLI merge_lora command."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_merge_lora_basic(cli_runner, config_path):
|
||||
"""Test basic merge_lora command"""
|
||||
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
|
||||
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
|
||||
|
||||
def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):
|
||||
"""Test merge_lora with custom lora and output directories"""
|
||||
lora_dir = tmp_path / "lora"
|
||||
output_dir = tmp_path / "output"
|
||||
lora_dir.mkdir()
|
||||
|
||||
with patch("axolotl.cli.merge_lora.do_cli") as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"merge-lora",
|
||||
str(config_path),
|
||||
"--lora-model-dir",
|
||||
str(lora_dir),
|
||||
"--output-dir",
|
||||
str(output_dir),
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock_do_cli.call_args.kwargs["lora_model_dir"] == str(lora_dir)
|
||||
assert mock_do_cli.call_args.kwargs["output_dir"] == str(output_dir)
|
||||
|
||||
|
||||
def test_merge_lora_nonexistent_config(cli_runner, tmp_path):
|
||||
"""Test merge_lora with nonexistent config"""
|
||||
config_path = tmp_path / "nonexistent.yml"
|
||||
result = cli_runner.invoke(cli, ["merge-lora", str(config_path)])
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):
|
||||
"""Test merge_lora with nonexistent lora directory"""
|
||||
lora_dir = tmp_path / "nonexistent"
|
||||
result = cli_runner.invoke(
|
||||
cli, ["merge-lora", str(config_path), "--lora-model-dir", str(lora_dir)]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
60
tests/cli/test_cli_merge_sharded_fsdp_weights.py
Normal file
60
tests/cli/test_cli_merge_sharded_fsdp_weights.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||
# pylint: disable=duplicate-code
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
|
||||
"""Test merge_sharded_fsdp_weights command without accelerate"""
|
||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
|
||||
"""Test merge_sharded_fsdp_weights command with model_dir option"""
|
||||
model_dir = tmp_path / "model"
|
||||
model_dir.mkdir()
|
||||
|
||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"merge-sharded-fsdp-weights",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--model-dir",
|
||||
str(model_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
|
||||
"""Test merge_sharded_fsdp_weights command with save_path option"""
|
||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"merge-sharded-fsdp-weights",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--save-path",
|
||||
"/path/to/save",
|
||||
],
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
|
||||
assert result.exit_code == 0
|
||||
71
tests/cli/test_cli_preprocess.py
Normal file
71
tests/cli/test_cli_preprocess.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""pytest tests for axolotl CLI preprocess command."""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_last_run_prepared():
|
||||
yield
|
||||
|
||||
if Path("last_run_prepared").exists():
|
||||
shutil.rmtree("last_run_prepared")
|
||||
|
||||
|
||||
def test_preprocess_config_not_found(cli_runner):
|
||||
"""Test preprocess fails when config not found"""
|
||||
result = cli_runner.invoke(cli, ["preprocess", "nonexistent.yml"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
def test_preprocess_basic(cli_runner, config_path):
|
||||
"""Test basic preprocessing with minimal config"""
|
||||
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
|
||||
result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock_do_cli.call_args.kwargs["download"] is True
|
||||
|
||||
|
||||
def test_preprocess_without_download(cli_runner, config_path):
|
||||
"""Test preprocessing without model download"""
|
||||
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["preprocess", str(config_path), "--no-download"]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock_do_cli.call_args.kwargs["download"] is False
|
||||
|
||||
|
||||
def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
|
||||
"""Test preprocessing with custom dataset path"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
custom_path = tmp_path / "custom_prepared"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"preprocess",
|
||||
str(config_path),
|
||||
"--dataset-prepared-path",
|
||||
str(custom_path.absolute()),
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
|
||||
custom_path.absolute()
|
||||
)
|
||||
76
tests/cli/test_cli_shard.py
Normal file
76
tests/cli/test_cli_shard.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""pytest tests for axolotl CLI shard command."""
|
||||
# pylint: disable=duplicate-code
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_shard_with_accelerate(cli_runner, config_path):
|
||||
"""Test shard command with accelerate"""
|
||||
with patch("subprocess.run") as mock:
|
||||
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.args[0] == [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
"axolotl.cli.shard",
|
||||
str(config_path),
|
||||
"--debug-num-examples",
|
||||
"0",
|
||||
]
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_shard_no_accelerate(cli_runner, config_path):
|
||||
"""Test shard command without accelerate"""
|
||||
with patch("axolotl.cli.shard.do_cli") as mock:
|
||||
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
|
||||
|
||||
assert mock.called
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
|
||||
"""Test shard command with model_dir option"""
|
||||
model_dir = tmp_path / "model"
|
||||
model_dir.mkdir()
|
||||
|
||||
with patch("axolotl.cli.shard.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"shard",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--model-dir",
|
||||
str(model_dir),
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_shard_with_save_dir(cli_runner, config_path):
|
||||
with patch("axolotl.cli.shard.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"shard",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--save-dir",
|
||||
"/path/to/save",
|
||||
],
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
|
||||
assert result.exit_code == 0
|
||||
98
tests/cli/test_cli_train.py
Normal file
98
tests/cli/test_cli_train.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""pytest tests for axolotl CLI train command."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_train_cli_validation(cli_runner):
|
||||
"""Test CLI validation"""
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, ["train", "--no-accelerate"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(cli, ["train", "nonexistent.yml", "--no-accelerate"])
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
|
||||
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
|
||||
"""Test basic successful execution"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock:
|
||||
result = cli_runner.invoke(cli, ["train", str(config_path)])
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.args[0] == [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(config_path),
|
||||
"--debug-num-examples",
|
||||
"0",
|
||||
]
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
|
||||
"""Test basic successful execution"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("axolotl.cli.train.train") as mock_train:
|
||||
mock_train.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--learning-rate",
|
||||
"1e-4",
|
||||
"--micro-batch-size",
|
||||
"2",
|
||||
"--no-accelerate",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_train.assert_called_once()
|
||||
|
||||
|
||||
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
|
||||
"""Test CLI arguments properly override config values"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
output_dir = tmp_path / "model-out"
|
||||
|
||||
test_config = valid_test_config.replace(
|
||||
"output_dir: model-out", f"output_dir: {output_dir}"
|
||||
)
|
||||
config_path.write_text(test_config)
|
||||
|
||||
with patch("axolotl.cli.train.train") as mock_train:
|
||||
mock_train.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--learning-rate",
|
||||
"1e-4",
|
||||
"--micro-batch-size",
|
||||
"2",
|
||||
"--no-accelerate",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_train.assert_called_once()
|
||||
cfg = mock_train.call_args[1]["cfg"]
|
||||
assert cfg["learning_rate"] == 1e-4
|
||||
assert cfg["micro_batch_size"] == 2
|
||||
89
tests/cli/test_utils.py
Normal file
89
tests/cli/test_utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""pytest tests for axolotl CLI utils."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from axolotl.cli.utils import fetch_from_github
|
||||
|
||||
# Sample GitHub API response
|
||||
MOCK_TREE_RESPONSE = {
|
||||
"tree": [
|
||||
{"path": "examples/config1.yml", "type": "blob", "sha": "abc123"},
|
||||
{"path": "examples/config2.yml", "type": "blob", "sha": "def456"},
|
||||
{"path": "other/file.txt", "type": "blob", "sha": "xyz789"},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_responses():
|
||||
"""Mock responses for API and file downloads"""
|
||||
|
||||
def mock_get(url, timeout=None): # pylint: disable=unused-argument
|
||||
response = Mock()
|
||||
if "api.github.com" in url:
|
||||
response.text = json.dumps(MOCK_TREE_RESPONSE)
|
||||
else:
|
||||
response.content = b"file content"
|
||||
return response
|
||||
|
||||
return mock_get
|
||||
|
||||
|
||||
def test_fetch_from_github_new_files(tmp_path, mock_responses):
|
||||
"""Test fetching new files"""
|
||||
with patch("requests.get", mock_responses):
|
||||
fetch_from_github("examples/", tmp_path)
|
||||
|
||||
# Verify files were created
|
||||
assert (tmp_path / "config1.yml").exists()
|
||||
assert (tmp_path / "config2.yml").exists()
|
||||
assert not (tmp_path / "file.txt").exists()
|
||||
|
||||
|
||||
def test_fetch_from_github_unchanged_files(tmp_path, mock_responses):
|
||||
"""Test handling of unchanged files"""
|
||||
# Create existing file with matching SHA
|
||||
existing_file = tmp_path / "config1.yml"
|
||||
existing_file.write_bytes(b"file content")
|
||||
|
||||
with patch("requests.get", mock_responses):
|
||||
fetch_from_github("examples/", tmp_path)
|
||||
|
||||
# File should not be downloaded again
|
||||
assert existing_file.read_bytes() == b"file content"
|
||||
|
||||
|
||||
def test_fetch_from_github_invalid_prefix(mock_responses):
|
||||
"""Test error handling for invalid directory prefix"""
|
||||
with patch("requests.get", mock_responses):
|
||||
with pytest.raises(click.ClickException):
|
||||
fetch_from_github("nonexistent/", None)
|
||||
|
||||
|
||||
def test_fetch_from_github_network_error():
|
||||
"""Test handling of network errors"""
|
||||
with patch("requests.get", side_effect=requests.RequestException):
|
||||
with pytest.raises(requests.RequestException):
|
||||
fetch_from_github("examples/", None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def integration_test_dir(tmp_path):
|
||||
"""Fixture for integration test directory that cleans up after itself"""
|
||||
test_dir = tmp_path / "github_downloads"
|
||||
test_dir.mkdir(parents=True)
|
||||
yield test_dir
|
||||
|
||||
|
||||
def test_fetch_from_github_real(integration_test_dir):
|
||||
"""Test actual GitHub API interaction"""
|
||||
fetch_from_github("examples/", integration_test_dir)
|
||||
|
||||
# Verify some known files exist
|
||||
assert (integration_test_dir / "openllama-3b" / "lora.yml").exists()
|
||||
assert (integration_test_dir / "openllama-3b" / "qlora.yml").exists()
|
||||
Reference in New Issue
Block a user