make sure to handle num-processes with cloud

This commit is contained in:
Wing Lian
2025-02-06 13:50:39 -05:00
parent ecea44c902
commit c82cbdc6d9
3 changed files with 12 additions and 5 deletions

View File

@@ -35,13 +35,14 @@ def do_cli_train(
cloud_config: Union[Path, str], cloud_config: Union[Path, str],
config: Union[Path, str], config: Union[Path, str],
accelerate: bool = True, accelerate: bool = True,
**kwargs,
) -> None: ) -> None:
print_axolotl_text_art() print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config) cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg) cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file: with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read() config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate) cloud.train(config_yaml, accelerate=accelerate, **kwargs)
def do_cli_lm_eval( def do_cli_lm_eval(

View File

@@ -217,7 +217,7 @@ class ModalCloud(Cloud):
secrets=self.get_secrets(), secrets=self.get_secrets(),
) )
def train(self, config_yaml: str, accelerate: bool = True): def train(self, config_yaml: str, accelerate: bool = True, **kwargs):
modal_fn = self.get_train_env()(_train) modal_fn = self.get_train_env()(_train)
with modal.enable_output(): with modal.enable_output():
with self.app.run(detach=True): with self.app.run(detach=True):
@@ -225,6 +225,7 @@ class ModalCloud(Cloud):
config_yaml, config_yaml,
accelerate=accelerate, accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()}, volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
) )
def lm_eval(self, config_yaml: str): def lm_eval(self, config_yaml: str):
@@ -255,7 +256,7 @@ def _preprocess(config_yaml: str, volumes=None):
) )
def _train(config_yaml: str, accelerate: bool = True, volumes=None): def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
with open( with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8" "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out: ) as f_out:
@@ -265,8 +266,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
accelerate_args = "--accelerate" accelerate_args = "--accelerate"
else: else:
accelerate_args = "--no-accelerate" accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd( run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml", f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
run_folder, run_folder,
volumes, volumes,
) )

View File

@@ -221,7 +221,9 @@ def train(
subprocess.run(cmd, check=True) # nosec B603 subprocess.run(cmd, check=True) # nosec B603
else: else:
if cloud: if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False) do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else: else:
from axolotl.cli.train import do_cli from axolotl.cli.train import do_cli