diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index e2dd3a6c3..b213e43da 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -2,7 +2,7 @@ import shutil from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -26,12 +26,15 @@ def test_preprocess_config_not_found(cli_runner): def test_preprocess_basic(cli_runner, config_path): """Test basic preprocessing with minimal config""" with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli: - result = cli_runner.invoke(cli, ["preprocess", str(config_path)]) - assert result.exit_code == 0 + with patch("axolotl.cli.preprocess.load_datasets") as mock_load_datasets: + mock_load_datasets.return_value = MagicMock() - mock_do_cli.assert_called_once() - assert mock_do_cli.call_args.kwargs["config"] == str(config_path) - assert mock_do_cli.call_args.kwargs["download"] is True + result = cli_runner.invoke(cli, ["preprocess", str(config_path)]) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + assert mock_do_cli.call_args.kwargs["download"] is True def test_preprocess_without_download(cli_runner, config_path): @@ -54,19 +57,22 @@ def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config): config_path.write_text(valid_test_config) with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli: - result = cli_runner.invoke( - cli, - [ - "preprocess", - str(config_path), - "--dataset-prepared-path", - str(custom_path.absolute()), - ], - ) - assert result.exit_code == 0 + with patch("axolotl.cli.preprocess.load_datasets") as mock_load_datasets: + mock_load_datasets.return_value = MagicMock() - mock_do_cli.assert_called_once() - assert mock_do_cli.call_args.kwargs["config"] == str(config_path) - assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str( - custom_path.absolute() - ) + result = cli_runner.invoke( + cli, + [ + "preprocess", + str(config_path), + "--dataset-prepared-path", + str(custom_path.absolute()), + ], + ) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str( + custom_path.absolute() + ) diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index a51251033..473913599 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -29,19 +29,21 @@ class TestTrainCommand(BaseCliTest): with patch("axolotl.cli.train.train") as mock_train: mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) + with patch("axolotl.cli.train.load_datasets") as mock_load_datasets: + mock_load_datasets.return_value = MagicMock() - result = cli_runner.invoke( - cli, - [ - "train", - str(config_path), - "--no-accelerate", - ], - catch_exceptions=False, - ) + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--no-accelerate", + ], + catch_exceptions=False, + ) - assert result.exit_code == 0 - mock_train.assert_called_once() + assert result.exit_code == 0 + mock_train.assert_called_once() def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config): """Test CLI arguments properly override config values""" @@ -49,23 +51,25 @@ class TestTrainCommand(BaseCliTest): with patch("axolotl.cli.train.train") as mock_train: mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) + with patch("axolotl.cli.train.load_datasets") as mock_load_datasets: + mock_load_datasets.return_value = MagicMock() - result = cli_runner.invoke( - cli, - [ - "train", - str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", - ], - catch_exceptions=False, - ) + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--learning-rate", + "1e-4", + "--micro-batch-size", + "2", + "--no-accelerate", + ], + catch_exceptions=False, + ) - assert result.exit_code == 0 - mock_train.assert_called_once() - cfg = mock_train.call_args[1]["cfg"] - assert cfg["learning_rate"] == 1e-4 - assert cfg["micro_batch_size"] == 2 + assert result.exit_code == 0 + mock_train.assert_called_once() + cfg = mock_train.call_args[1]["cfg"] + assert cfg["learning_rate"] == 1e-4 + assert cfg["micro_batch_size"] == 2