252 lines
8.9 KiB
Python
252 lines
8.9 KiB
Python
"""Tests for train CLI command."""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from axolotl.cli.main import cli
|
|
|
|
from .test_cli_base import BaseCliTest
|
|
|
|
|
|
class TestTrainCommand(BaseCliTest):
    """Test cases for train command."""

    cli = cli

    @staticmethod
    def _write_train_config(tmp_path, valid_test_config):
        """Write the config text to ``tmp_path / "config.yml"`` and return the path."""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)
        return config_path

    def test_train_cli_validation(self, cli_runner):
        """Test CLI validation"""
        self._test_cli_validation(cli_runner, "train")

    def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
        """Test basic successful execution"""
        self._test_basic_execution(
            cli_runner, tmp_path, valid_test_config, "train", train=True
        )

    def test_train_basic_execution_no_accelerate(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test basic successful execution without accelerate

        With ``--launcher python`` the patched ``train`` function must be
        called exactly once and the command must exit cleanly.
        """
        config_path = self._write_train_config(tmp_path, valid_test_config)

        with patch("axolotl.cli.train.train") as mock_train:
            mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
            with patch("axolotl.cli.train.load_datasets") as mock_load_datasets:
                mock_load_datasets.return_value = MagicMock()

                result = cli_runner.invoke(
                    cli,
                    [
                        "train",
                        str(config_path),
                        "--launcher",
                        "python",
                    ],
                    catch_exceptions=False,
                )

        assert result.exit_code == 0
        mock_train.assert_called_once()

    def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
        """Test CLI arguments properly override config values"""
        config_path = self._test_cli_overrides(tmp_path, valid_test_config)

        with patch("axolotl.cli.train.train") as mock_train:
            mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
            with patch("axolotl.cli.train.load_datasets") as mock_load_datasets:
                mock_load_datasets.return_value = MagicMock()

                result = cli_runner.invoke(
                    cli,
                    [
                        "train",
                        str(config_path),
                        "--learning-rate=1e-4",
                        "--micro-batch-size=2",
                        "--launcher",
                        "python",
                    ],
                    catch_exceptions=False,
                )

        assert result.exit_code == 0
        mock_train.assert_called_once()
        # Read the keyword arguments by name, like the other tests in this class.
        cfg = mock_train.call_args.kwargs["cfg"]
        assert cfg["learning_rate"] == 1e-4
        assert cfg["micro_batch_size"] == 2

    def test_train_with_launcher_args_torchrun(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with torchrun launcher arguments

        Arguments given after ``--`` must show up in the command handed to
        ``os.execvpe``.
        """
        config_path = self._write_train_config(tmp_path, valid_test_config)

        # Patch os.execvpe so the test process is not replaced by the launcher.
        with patch("os.execvpe") as mock_execvpe:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "torchrun",
                    "--",
                    "--nproc_per_node=2",
                    "--nnodes=1",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_execvpe.assert_called_once()

        # Verify launcher args are passed to torchrun
        assert mock_execvpe.call_args.args[0] == "torchrun"
        called_cmd = mock_execvpe.call_args.args[1]
        assert called_cmd[0] == "torchrun"
        assert "--nproc_per_node=2" in called_cmd
        assert "--nnodes=1" in called_cmd
        assert "-m" in called_cmd
        assert "axolotl.cli.train" in called_cmd

    def test_train_with_launcher_args_accelerate(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with accelerate launcher arguments

        Arguments given after ``--`` must show up in the ``accelerate launch``
        command handed to ``os.execvpe``.
        """
        config_path = self._write_train_config(tmp_path, valid_test_config)

        # Patch os.execvpe so the test process is not replaced by the launcher.
        with patch("os.execvpe") as mock_execvpe:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "accelerate",
                    "--",
                    "--config_file=accelerate_config.yml",
                    "--num_processes=4",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_execvpe.assert_called_once()

        # Verify launcher args are passed to accelerate
        assert mock_execvpe.call_args.args[0] == "accelerate"
        called_cmd = mock_execvpe.call_args.args[1]
        assert called_cmd[0] == "accelerate"
        assert called_cmd[1] == "launch"
        assert "--config_file=accelerate_config.yml" in called_cmd
        assert "--num_processes=4" in called_cmd
        assert "-m" in called_cmd
        assert "axolotl.cli.train" in called_cmd

    def test_train_backward_compatibility_no_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test that existing train commands work without launcher args"""
        config_path = self._write_train_config(tmp_path, valid_test_config)

        # Patch os.execvpe so the test process is not replaced by the launcher.
        with patch("os.execvpe") as mock_execvpe:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "accelerate",
                    "--learning-rate",
                    "1e-4",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_execvpe.assert_called_once()

        # Verify no launcher args contamination
        assert mock_execvpe.call_args.args[0] == "accelerate"
        called_cmd = mock_execvpe.call_args.args[1]
        assert called_cmd[0] == "accelerate"
        assert called_cmd[1] == "launch"
        # Should not contain any extra launcher args:
        # nothing may sit between 'launch' and '-m'.
        launcher_section = called_cmd[2 : called_cmd.index("-m")]
        assert not launcher_section

    def test_train_mixed_args_with_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with both regular CLI args and launcher args"""
        config_path = self._write_train_config(tmp_path, valid_test_config)

        # Patch os.execvpe so the test process is not replaced by the launcher.
        with patch("os.execvpe") as mock_execvpe:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "torchrun",
                    "--learning-rate",
                    "2e-4",
                    "--micro-batch-size",
                    "4",
                    "--",
                    "--nproc_per_node=8",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_execvpe.assert_called_once()

        assert mock_execvpe.call_args.args[0] == "torchrun"
        called_cmd = mock_execvpe.call_args.args[1]
        # Verify launcher args
        assert "--nproc_per_node=8" in called_cmd
        # Verify axolotl args are also present
        assert "--learning-rate=2e-4" in called_cmd
        assert "--micro-batch-size=4" in called_cmd

    def test_train_cloud_with_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with cloud and launcher arguments"""
        config_path = self._write_train_config(tmp_path, valid_test_config)

        cloud_path = tmp_path / "cloud.yml"
        cloud_path.write_text("provider: modal\ngpu: a100")

        with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--cloud",
                    str(cloud_path),
                    "--launcher",
                    "torchrun",
                    "--",
                    "--nproc_per_node=4",
                    "--nnodes=2",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_cloud_train.assert_called_once()

        # Verify cloud training was called with launcher args
        call_kwargs = mock_cloud_train.call_args.kwargs
        assert call_kwargs["launcher"] == "torchrun"
        assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]
|