use exec instead of subprocess to make ctrl+c nicer for cli (#3044)
* use exec instead of subprocess to make ctrl+c nicer for cli * change var name to use_exec * simplify to bool * flush std* * patch subprocess as mock in test * fix tests * more test fixes
This commit is contained in:
@@ -123,9 +123,10 @@ def train(
|
|||||||
_launcher = None if kwargs.get("use_ray") else launcher
|
_launcher = None if kwargs.get("use_ray") else launcher
|
||||||
|
|
||||||
# Process each configuration
|
# Process each configuration
|
||||||
for cfg_file in generate_config_files(config, sweep):
|
for cfg_file, is_group in generate_config_files(config, sweep):
|
||||||
try:
|
try:
|
||||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
use_exec = is_group is not True
|
||||||
|
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||||
if not sweep:
|
if not sweep:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Iterator, Literal
|
from typing import Any, Iterator, Literal
|
||||||
|
|
||||||
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
|||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||||
"""Generate list of configuration files to process."""
|
"""
|
||||||
|
Generate list of configuration files to process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Base configuration file
|
||||||
|
sweep: Sweep configuration file
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuple of configuration file name and whether this is a group of configurations
|
||||||
|
"""
|
||||||
|
|
||||||
if not sweep:
|
if not sweep:
|
||||||
yield config
|
yield config, False
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load sweep and base configurations
|
# Load sweep and base configurations
|
||||||
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
|||||||
|
|
||||||
# Generate all possible configurations
|
# Generate all possible configurations
|
||||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||||
|
is_group = len(permutations) > 1
|
||||||
for permutation in permutations:
|
for permutation in permutations:
|
||||||
# pylint: disable=consider-using-with
|
# pylint: disable=consider-using-with
|
||||||
temp_file = tempfile.NamedTemporaryFile(
|
temp_file = tempfile.NamedTemporaryFile(
|
||||||
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
|||||||
)
|
)
|
||||||
yaml.dump(permutation, temp_file)
|
yaml.dump(permutation, temp_file)
|
||||||
temp_file.close()
|
temp_file.close()
|
||||||
yield temp_file.name
|
yield temp_file.name, is_group
|
||||||
|
|
||||||
|
|
||||||
def launch_training(
|
def launch_training(
|
||||||
@@ -97,6 +109,7 @@ def launch_training(
|
|||||||
cloud: str | None,
|
cloud: str | None,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
launcher_args: list[str] | None = None,
|
launcher_args: list[str] | None = None,
|
||||||
|
use_exec: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training with the given configuration."""
|
"""Execute training with the given configuration."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -105,9 +118,9 @@ def launch_training(
|
|||||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||||
elif launcher:
|
elif launcher:
|
||||||
if launcher == "accelerate":
|
if launcher == "accelerate":
|
||||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||||
elif launcher == "torchrun":
|
elif launcher == "torchrun":
|
||||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||||
elif launcher == "python":
|
elif launcher == "python":
|
||||||
_launch_python_training(cfg_file, kwargs)
|
_launch_python_training(cfg_file, kwargs)
|
||||||
|
|
||||||
@@ -136,7 +149,10 @@ def _launch_cloud_training(
|
|||||||
|
|
||||||
|
|
||||||
def _launch_accelerate_training(
|
def _launch_accelerate_training(
|
||||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
cfg_file: str,
|
||||||
|
kwargs: dict,
|
||||||
|
launcher_args: list[str] | None = None,
|
||||||
|
use_exec: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training via accelerate launcher."""
|
"""Execute training via accelerate launcher."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -161,11 +177,20 @@ def _launch_accelerate_training(
|
|||||||
base_cmd.append(cfg_file)
|
base_cmd.append(cfg_file)
|
||||||
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
cmd = build_command(base_cmd, kwargs)
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
if use_exec:
|
||||||
|
# make sure to flush stdout and stderr before replacing the process
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||||
|
else:
|
||||||
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
|
|
||||||
|
|
||||||
def _launch_torchrun_training(
|
def _launch_torchrun_training(
|
||||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
cfg_file: str,
|
||||||
|
kwargs: dict,
|
||||||
|
launcher_args: list[str] | None = None,
|
||||||
|
use_exec: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training via torchrun launcher."""
|
"""Execute training via torchrun launcher."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -178,7 +203,13 @@ def _launch_torchrun_training(
|
|||||||
base_cmd.append(cfg_file)
|
base_cmd.append(cfg_file)
|
||||||
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
cmd = build_command(base_cmd, kwargs)
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
if use_exec:
|
||||||
|
# make sure to flush stdout and stderr before replacing the process
|
||||||
|
sys.stdout.flush()
|
||||||
|
sys.stderr.flush()
|
||||||
|
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||||
|
else:
|
||||||
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
|
|
||||||
|
|
||||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ class BaseCliTest:
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock:
|
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
||||||
|
|
||||||
|
with patch(mock_fn) as mock:
|
||||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||||
|
|
||||||
assert mock.called
|
assert mock.called
|
||||||
@@ -65,8 +67,12 @@ class BaseCliTest:
|
|||||||
if train:
|
if train:
|
||||||
expected.append("--shard=False")
|
expected.append("--shard=False")
|
||||||
|
|
||||||
assert mock.call_args.args[0] == expected
|
if command == "train":
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
assert mock.call_args.args[0] == "accelerate"
|
||||||
|
assert mock.call_args.args[1] == expected
|
||||||
|
else:
|
||||||
|
assert mock.call_args.args[0] == expected
|
||||||
|
assert mock.call_args.kwargs == {"check": True}
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|
||||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify launcher args are passed to torchrun
|
# Verify launcher args are passed to torchrun
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
assert called_cmd[0] == "torchrun"
|
assert called_cmd[0] == "torchrun"
|
||||||
assert "--nproc_per_node=2" in called_cmd
|
assert "--nproc_per_node=2" in called_cmd
|
||||||
assert "--nnodes=1" in called_cmd
|
assert "--nnodes=1" in called_cmd
|
||||||
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify launcher args are passed to accelerate
|
# Verify launcher args are passed to accelerate
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||||
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
assert called_cmd[0] == "accelerate"
|
assert called_cmd[0] == "accelerate"
|
||||||
assert called_cmd[1] == "launch"
|
assert called_cmd[1] == "launch"
|
||||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||||
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify no launcher args contamination
|
# Verify no launcher args contamination
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||||
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
assert called_cmd[0] == "accelerate"
|
assert called_cmd[0] == "accelerate"
|
||||||
assert called_cmd[1] == "launch"
|
assert called_cmd[1] == "launch"
|
||||||
# Should not contain any extra launcher args
|
# Should not contain any extra launcher args
|
||||||
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("subprocess.run") as mock_subprocess:
|
with patch("os.execvpe") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
called_cmd = mock_subprocess.call_args.args[0]
|
assert mock_subprocess.call_args.args[0] == "torchrun"
|
||||||
|
called_cmd = mock_subprocess.call_args.args[1]
|
||||||
# Verify launcher args
|
# Verify launcher args
|
||||||
assert "--nproc_per_node=8" in called_cmd
|
assert "--nproc_per_node=8" in called_cmd
|
||||||
# Verify axolotl args are also present
|
# Verify axolotl args are also present
|
||||||
|
|||||||
Reference in New Issue
Block a user