make sure to handle num-processes with cloud
This commit is contained in:
@@ -35,13 +35,14 @@ def do_cli_train(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
accelerate: bool = True,
|
accelerate: bool = True,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
cloud.train(config_yaml, accelerate=accelerate)
|
cloud.train(config_yaml, accelerate=accelerate, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def do_cli_lm_eval(
|
def do_cli_lm_eval(
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ class ModalCloud(Cloud):
|
|||||||
secrets=self.get_secrets(),
|
secrets=self.get_secrets(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def train(self, config_yaml: str, accelerate: bool = True):
|
def train(self, config_yaml: str, accelerate: bool = True, **kwargs):
|
||||||
modal_fn = self.get_train_env()(_train)
|
modal_fn = self.get_train_env()(_train)
|
||||||
with modal.enable_output():
|
with modal.enable_output():
|
||||||
with self.app.run(detach=True):
|
with self.app.run(detach=True):
|
||||||
@@ -225,6 +225,7 @@ class ModalCloud(Cloud):
|
|||||||
config_yaml,
|
config_yaml,
|
||||||
accelerate=accelerate,
|
accelerate=accelerate,
|
||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def lm_eval(self, config_yaml: str):
|
def lm_eval(self, config_yaml: str):
|
||||||
@@ -255,7 +256,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||||
with open(
|
with open(
|
||||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||||
) as f_out:
|
) as f_out:
|
||||||
@@ -265,8 +266,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
|||||||
accelerate_args = "--accelerate"
|
accelerate_args = "--accelerate"
|
||||||
else:
|
else:
|
||||||
accelerate_args = "--no-accelerate"
|
accelerate_args = "--no-accelerate"
|
||||||
|
num_processes_args = ""
|
||||||
|
if num_processes := kwargs.pop("num_processes", None):
|
||||||
|
num_processes_args = f"--num-processes {num_processes}"
|
||||||
run_cmd(
|
run_cmd(
|
||||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
|
||||||
run_folder,
|
run_folder,
|
||||||
volumes,
|
volumes,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -221,7 +221,9 @@ def train(
|
|||||||
subprocess.run(cmd, check=True) # nosec B603
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
else:
|
else:
|
||||||
if cloud:
|
if cloud:
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
do_cli_train(
|
||||||
|
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from axolotl.cli.train import do_cli
|
from axolotl.cli.train import do_cli
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user