Compare commits

...

6 Commits

Author SHA1 Message Date
Wing Lian
6905711e45 set max steps to -1 when empty 2025-02-06 17:27:52 -05:00
Wing Lian
bb5a6135eb don't set total num steps for grpo 2025-02-06 17:23:13 -05:00
Wing Lian
e637f9b1a4 cleanup pythonpath if axo in it 2025-02-06 17:03:21 -05:00
Wing Lian
1a3bfd6e0f test not deleting pythonpath for custom code bundling
clean path and add mounts
handle mounting
2025-02-06 17:01:19 -05:00
Wing Lian
3df4df868c make sure to pass kwargs when using accelerate 2025-02-06 14:00:15 -05:00
Wing Lian
c82cbdc6d9 make sure to handle num-processes with cloud 2025-02-06 13:50:39 -05:00
5 changed files with 49 additions and 14 deletions

View File

@@ -35,13 +35,18 @@ def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cwd=None,
**kwargs,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
def do_cli_lm_eval(

View File

@@ -7,6 +7,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
import modal
@@ -22,10 +23,17 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
python_path = Path(new_env["PYTHONPATH"].split(":")[0])
if python_path.joinpath("src", "axolotl").exists():
# we don't want to use the automounted axolotl or unexpected behavior happens
paths = ["/workspace/mounts"]
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
sub_python_path = Path(sub_python_path_str)
if not sub_python_path.joinpath("src", "axolotl").exists():
# we don't want to use the automounted axolotl or unexpected behavior happens
paths.append(str(sub_python_path))
if paths:
new_env["PYTHONPATH"] = ":".join(paths)
else:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
@@ -206,9 +214,12 @@ class ModalCloud(Cloud):
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
def get_train_env(self, local_dirs=None):
image = self.get_image()
for mount, local_dir in (local_dirs or {}).items():
image = image.add_local_dir(local_dir, mount)
return self.app.function(
image=self.get_image(),
image=image,
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
@@ -217,14 +228,21 @@ class ModalCloud(Cloud):
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
def lm_eval(self, config_yaml: str):
@@ -255,7 +273,7 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
@@ -265,8 +283,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -2,6 +2,7 @@
# pylint: disable=redefined-outer-name
import logging
import os
import random
import subprocess # nosec B404
import tempfile
@@ -200,7 +201,10 @@ def train(
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud, config=config, accelerate=True, cwd=cwd, **kwargs
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
@@ -221,7 +225,9 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli

View File

@@ -122,9 +122,11 @@ def load_preference_datasets(
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")

View File

@@ -1032,10 +1032,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
max_steps = self.cfg.max_steps or total_num_steps or -1
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
max_steps=max_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
warmup_steps=self.cfg.warmup_steps,