From c82cbdc6d918220c7e68ec65672c6cd6c4ac3680 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 6 Feb 2025 13:50:39 -0500 Subject: [PATCH] make sure to handle num-processes with cloud --- src/axolotl/cli/cloud/__init__.py | 3 ++- src/axolotl/cli/cloud/modal_.py | 10 +++++++--- src/axolotl/cli/main.py | 4 +++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index fde46e397..40e192549 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -35,13 +35,14 @@ def do_cli_train( cloud_config: Union[Path, str], config: Union[Path, str], accelerate: bool = True, + **kwargs, ) -> None: print_axolotl_text_art() cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) with open(config, "r", encoding="utf-8") as file: config_yaml = file.read() - cloud.train(config_yaml, accelerate=accelerate) + cloud.train(config_yaml, accelerate=accelerate, **kwargs) def do_cli_lm_eval( diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index 52e8cc3c6..ced244f21 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -217,7 +217,7 @@ class ModalCloud(Cloud): secrets=self.get_secrets(), ) - def train(self, config_yaml: str, accelerate: bool = True): + def train(self, config_yaml: str, accelerate: bool = True, **kwargs): modal_fn = self.get_train_env()(_train) with modal.enable_output(): with self.app.run(detach=True): @@ -225,6 +225,7 @@ class ModalCloud(Cloud): config_yaml, accelerate=accelerate, volumes={k: v[0] for k, v in self.volumes.items()}, + **kwargs, ) def lm_eval(self, config_yaml: str): @@ -255,7 +256,7 @@ def _preprocess(config_yaml: str, volumes=None): ) -def _train(config_yaml: str, accelerate: bool = True, volumes=None): +def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): with open( "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8" ) as f_out: @@ -265,8 +266,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None): accelerate_args = "--accelerate" else: accelerate_args = "--no-accelerate" + num_processes_args = "" + if num_processes := kwargs.pop("num_processes", None): + num_processes_args = f"--num-processes {num_processes}" run_cmd( - f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml", + f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml", run_folder, volumes, ) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index e8dbf9f12..d1a85b929 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -221,7 +221,9 @@ def train( subprocess.run(cmd, check=True) # nosec B603 else: if cloud: - do_cli_train(cloud_config=cloud, config=config, accelerate=False) + do_cli_train( + cloud_config=cloud, config=config, accelerate=False, **kwargs + ) else: from axolotl.cli.train import do_cli