From 47304c7f8acf11b54e66c247a8a995b9e95658b4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 10 Aug 2025 20:22:20 -0400 Subject: [PATCH] use exec instead of subprocess to make ctrl+c nicer for cli (#3044) * use exec instead of subprocess to make ctrl+c nicer for cli * change var name to use_exec * simplify to bool * flush std* * patch subprocess as mock in test * fix tests * more test fixes --- src/axolotl/cli/main.py | 5 ++-- src/axolotl/cli/utils/train.py | 51 +++++++++++++++++++++++++++------- tests/cli/test_cli_base.py | 12 ++++++-- tests/cli/test_cli_train.py | 19 +++++++------ 4 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index c41acc40b..e63392802 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -123,9 +123,10 @@ def train( _launcher = None if kwargs.get("use_ray") else launcher # Process each configuration - for cfg_file in generate_config_files(config, sweep): + for cfg_file, is_group in generate_config_files(config, sweep): try: - launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args) + use_exec = is_group is not True + launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec) except subprocess.CalledProcessError as exc: LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") if not sweep: diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py index 61d05e52b..3f9a6e4db 100644 --- a/src/axolotl/cli/utils/train.py +++ b/src/axolotl/cli/utils/train.py @@ -2,6 +2,7 @@ import os import subprocess # nosec +import sys import tempfile from typing import Any, Iterator, Literal @@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: return cmd -def generate_config_files(config: str, sweep: str | None) -> Iterator[str]: - """Generate list of configuration files to process.""" +def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]: + """ + Generate list of configuration files to process. + + Args: + config: Base configuration file + sweep: Sweep configuration file + + Yields: + Tuple of configuration file name and whether this is a group of configurations + """ + if not sweep: - yield config + yield config, False return # Load sweep and base configurations @@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]: # Generate all possible configurations permutations = generate_sweep_configs(base_config, sweep_config) + is_group = len(permutations) > 1 for permutation in permutations: # pylint: disable=consider-using-with temp_file = tempfile.NamedTemporaryFile( @@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]: ) yaml.dump(permutation, temp_file) temp_file.close() - yield temp_file.name + yield temp_file.name, is_group def launch_training( @@ -97,6 +109,7 @@ def launch_training( cloud: str | None, kwargs: dict, launcher_args: list[str] | None = None, + use_exec: bool = False, ) -> None: """Execute training with the given configuration.""" launcher_args = launcher_args or [] @@ -105,9 +118,9 @@ def launch_training( _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args) elif launcher: if launcher == "accelerate": - _launch_accelerate_training(cfg_file, kwargs, launcher_args) + _launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec) elif launcher == "torchrun": - _launch_torchrun_training(cfg_file, kwargs, launcher_args) + _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec) elif launcher == "python": _launch_python_training(cfg_file, kwargs) @@ -136,7 +149,10 @@ def _launch_cloud_training( def _launch_accelerate_training( - cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None + cfg_file: str, + kwargs: dict, + launcher_args: list[str] | None = None, + use_exec: bool = False, ) -> None: """Execute training via accelerate launcher.""" launcher_args = launcher_args or [] @@ -161,11 +177,20 @@ def _launch_accelerate_training( base_cmd.append(cfg_file) cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 + if use_exec: + # make sure to flush stdout and stderr before replacing the process + sys.stdout.flush() + sys.stderr.flush() + os.execvpe(cmd[0], cmd, os.environ) # nosec B606 + else: + subprocess.run(cmd, check=True) # nosec B603 def _launch_torchrun_training( - cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None + cfg_file: str, + kwargs: dict, + launcher_args: list[str] | None = None, + use_exec: bool = False, ) -> None: """Execute training via torchrun launcher.""" launcher_args = launcher_args or [] @@ -178,7 +203,13 @@ def _launch_torchrun_training( base_cmd.append(cfg_file) cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 + if use_exec: + # make sure to flush stdout and stderr before replacing the process + sys.stdout.flush() + sys.stderr.flush() + os.execvpe(cmd[0], cmd, os.environ) # nosec B606 + else: + subprocess.run(cmd, check=True) # nosec B603 def _launch_python_training(cfg_file: str, kwargs: dict) -> None: diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py index 4b880d44a..e28bbb75c 100644 --- a/tests/cli/test_cli_base.py +++ b/tests/cli/test_cli_base.py @@ -47,7 +47,9 @@ class BaseCliTest: config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) - with patch("subprocess.run") as mock: + mock_fn = "os.execvpe" if command == "train" else "subprocess.run" + + with patch(mock_fn) as mock: result = cli_runner.invoke(cli, [command, str(config_path)]) assert mock.called @@ -65,8 +67,12 @@ class BaseCliTest: if train: expected.append("--shard=False") - assert mock.call_args.args[0] == expected - assert mock.call_args.kwargs == {"check": True} + if command == "train": + assert mock.call_args.args[0] == "accelerate" + assert mock.call_args.args[1] == expected + else: + assert mock.call_args.args[0] == expected + assert mock.call_args.kwargs == {"check": True} assert result.exit_code == 0 def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str): diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index 9b266f057..d4d90f57f 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest): config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) - with patch("subprocess.run") as mock_subprocess: + with patch("os.execvpe") as mock_subprocess: result = cli_runner.invoke( cli, [ @@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest): mock_subprocess.assert_called_once() # Verify launcher args are passed to torchrun - called_cmd = mock_subprocess.call_args.args[0] + called_cmd = mock_subprocess.call_args.args[1] assert called_cmd[0] == "torchrun" assert "--nproc_per_node=2" in called_cmd assert "--nnodes=1" in called_cmd @@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest): config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) - with patch("subprocess.run") as mock_subprocess: + with patch("os.execvpe") as mock_subprocess: result = cli_runner.invoke( cli, [ @@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest): mock_subprocess.assert_called_once() # Verify launcher args are passed to accelerate - called_cmd = mock_subprocess.call_args.args[0] + assert mock_subprocess.call_args.args[0] == "accelerate" + called_cmd = mock_subprocess.call_args.args[1] assert called_cmd[0] == "accelerate" assert called_cmd[1] == "launch" assert "--config_file=accelerate_config.yml" in called_cmd @@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest): config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) - with patch("subprocess.run") as mock_subprocess: + with patch("os.execvpe") as mock_subprocess: result = cli_runner.invoke( cli, [ @@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest): mock_subprocess.assert_called_once() # Verify no launcher args contamination - called_cmd = mock_subprocess.call_args.args[0] + assert mock_subprocess.call_args.args[0] == "accelerate" + called_cmd = mock_subprocess.call_args.args[1] assert called_cmd[0] == "accelerate" assert called_cmd[1] == "launch" # Should not contain any extra launcher args @@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest): config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) - with patch("subprocess.run") as mock_subprocess: + with patch("os.execvpe") as mock_subprocess: result = cli_runner.invoke( cli, [ @@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest): assert result.exit_code == 0 mock_subprocess.assert_called_once() - called_cmd = mock_subprocess.call_args.args[0] + assert mock_subprocess.call_args.args[0] == "torchrun" + called_cmd = mock_subprocess.call_args.args[1] # Verify launcher args assert "--nproc_per_node=8" in called_cmd # Verify axolotl args are also present