use exec instead of subprocess to make ctrl+c nicer for cli (#3044)
* use exec instead of subprocess to make ctrl+c nicer for cli * change var name to use_exec * simplify to bool * flush std* * patch subprocess as mock in test * fix tests * more test fixes
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
"""Generate list of configuration files to process."""
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||
"""
|
||||
Generate list of configuration files to process.
|
||||
|
||||
Args:
|
||||
config: Base configuration file
|
||||
sweep: Sweep configuration file
|
||||
|
||||
Yields:
|
||||
Tuple of configuration file name and whether this is a group of configurations
|
||||
"""
|
||||
|
||||
if not sweep:
|
||||
yield config
|
||||
yield config, False
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name
|
||||
yield temp_file.name, is_group
|
||||
|
||||
|
||||
def launch_training(
|
||||
@@ -97,6 +109,7 @@ def launch_training(
|
||||
cloud: str | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training with the given configuration."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -105,9 +118,9 @@ def launch_training(
|
||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||
elif launcher:
|
||||
if launcher == "accelerate":
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "torchrun":
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "python":
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
|
||||
@@ -136,7 +149,10 @@ def _launch_cloud_training(
|
||||
|
||||
|
||||
def _launch_accelerate_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via accelerate launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -161,11 +177,20 @@ def _launch_accelerate_training(
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_torchrun_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via torchrun launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -178,7 +203,13 @@ def _launch_torchrun_training(
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||
|
||||
Reference in New Issue
Block a user