use exec instead of subprocess to make ctrl+c nicer for cli (#3044)

* use exec instead of subprocess to make ctrl+c nicer for cli

* change var name to use_exec

* simplify to bool

* flush std*

* patch subprocess as mock in test

* fix tests

* more test fixes
This commit is contained in:
Wing Lian
2025-08-10 20:22:20 -04:00
parent 2c8497e489
commit 47304c7f8a
4 changed files with 64 additions and 23 deletions

View File

@@ -123,9 +123,10 @@ def train(
_launcher = None if kwargs.get("use_ray") else launcher
# Process each configuration
for cfg_file in generate_config_files(config, sweep):
for cfg_file, is_group in generate_config_files(config, sweep):
try:
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:

View File

@@ -2,6 +2,7 @@
import os
import subprocess # nosec
import sys
import tempfile
from typing import Any, Iterator, Literal
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
return cmd
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
"""Generate list of configuration files to process."""
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
"""
Generate list of configuration files to process.
Args:
config: Base configuration file
sweep: Sweep configuration file
Yields:
Tuple of configuration file name and whether this is a group of configurations
"""
if not sweep:
yield config
yield config, False
return
# Load sweep and base configurations
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
)
yaml.dump(permutation, temp_file)
temp_file.close()
yield temp_file.name
yield temp_file.name, is_group
def launch_training(
@@ -97,6 +109,7 @@ def launch_training(
cloud: str | None,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training with the given configuration."""
launcher_args = launcher_args or []
@@ -105,9 +118,9 @@ def launch_training(
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
elif launcher:
if launcher == "accelerate":
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "torchrun":
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "python":
_launch_python_training(cfg_file, kwargs)
@@ -136,7 +149,10 @@ def _launch_cloud_training(
def _launch_accelerate_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via accelerate launcher."""
launcher_args = launcher_args or []
@@ -161,11 +177,20 @@ def _launch_accelerate_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_torchrun_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via torchrun launcher."""
launcher_args = launcher_args or []
@@ -178,7 +203,13 @@ def _launch_torchrun_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:

View File

@@ -47,7 +47,9 @@ class BaseCliTest:
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
with patch(mock_fn) as mock:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
@@ -65,8 +67,12 @@ class BaseCliTest:
if train:
expected.append("--shard=False")
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
if command == "train":
assert mock.call_args.args[0] == "accelerate"
assert mock.call_args.args[1] == expected
else:
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):

View File

@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
assert result.exit_code == 0
mock_subprocess.assert_called_once()
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "torchrun"
called_cmd = mock_subprocess.call_args.args[1]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present