Make sure to handle the `--num-processes` option when training with cloud

This commit is contained in:
Wing Lian
2025-02-06 13:50:39 -05:00
parent ecea44c902
commit c82cbdc6d9
3 changed files with 12 additions and 5 deletions

View File

@@ -35,13 +35,14 @@ def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
**kwargs,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
cloud.train(config_yaml, accelerate=accelerate, **kwargs)
def do_cli_lm_eval(

View File

@@ -217,7 +217,7 @@ class ModalCloud(Cloud):
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
def train(self, config_yaml: str, accelerate: bool = True, **kwargs):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
@@ -225,6 +225,7 @@ class ModalCloud(Cloud):
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
def lm_eval(self, config_yaml: str):
@@ -255,7 +256,7 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
@@ -265,8 +266,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -221,7 +221,9 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli