diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index fde46e397..40e192549 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -35,13 +35,14 @@ def do_cli_train( cloud_config: Union[Path, str], config: Union[Path, str], accelerate: bool = True, + **kwargs, ) -> None: print_axolotl_text_art() cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) with open(config, "r", encoding="utf-8") as file: config_yaml = file.read() - cloud.train(config_yaml, accelerate=accelerate) + cloud.train(config_yaml, accelerate=accelerate, **kwargs) def do_cli_lm_eval( diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index 52e8cc3c6..ced244f21 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -217,7 +217,7 @@ class ModalCloud(Cloud): secrets=self.get_secrets(), ) - def train(self, config_yaml: str, accelerate: bool = True): + def train(self, config_yaml: str, accelerate: bool = True, **kwargs): modal_fn = self.get_train_env()(_train) with modal.enable_output(): with self.app.run(detach=True): @@ -225,6 +225,7 @@ class ModalCloud(Cloud): config_yaml, accelerate=accelerate, volumes={k: v[0] for k, v in self.volumes.items()}, + **kwargs, ) def lm_eval(self, config_yaml: str): @@ -255,7 +256,7 @@ def _preprocess(config_yaml: str, volumes=None): ) -def _train(config_yaml: str, accelerate: bool = True, volumes=None): +def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): with open( "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8" ) as f_out: @@ -265,8 +266,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None): accelerate_args = "--accelerate" else: accelerate_args = "--no-accelerate" + num_processes_args = "" + if num_processes := kwargs.pop("num_processes", None): + num_processes_args = f"--num-processes {num_processes}" run_cmd( - f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml", + f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml", run_folder, volumes, ) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index e8dbf9f12..d1a85b929 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -221,7 +221,9 @@ def train( subprocess.run(cmd, check=True) # nosec B603 else: if cloud: - do_cli_train(cloud_config=cloud, config=config, accelerate=False) + do_cli_train( + cloud_config=cloud, config=config, accelerate=False, **kwargs + ) else: from axolotl.cli.train import do_cli