use exec instead of subprocess to make ctrl+c nicer for cli (#3044)

* use exec instead of subprocess to make ctrl+c nicer for cli

* change var name to use_exec

* simplify to bool

* flush std*

* patch subprocess as mock in test

* fix tests

* more test fixes
This commit is contained in:
Wing Lian
2025-08-10 20:22:20 -04:00
parent 2c8497e489
commit 47304c7f8a
4 changed files with 64 additions and 23 deletions

View File

@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
assert result.exit_code == 0
mock_subprocess.assert_called_once()
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "torchrun"
called_cmd = mock_subprocess.call_args.args[1]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present