CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)
* add --launcher option; explicit True/False bool args; small cleanup * refactor * add torchrun, accelerate cli args * add rdzv arg default + tests * update _quarto * coderabbit * fix * we can't set rdvz_id independently across nodes * coderabbit * fix tests
This commit is contained in:
@@ -30,8 +30,6 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=0)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
main_process_port: Optional[int] = field(default=None)
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
@@ -31,9 +31,10 @@ def do_cli_preprocess(
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -44,12 +45,18 @@ def do_cli_train(
|
||||
local_dirs = {}
|
||||
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||
local_dirs = {"/workspace/mounts": cwd}
|
||||
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||
cloud.train(
|
||||
config_yaml,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
local_dirs=local_dirs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
|
||||
@@ -3,6 +3,7 @@ base class for cloud platforms from cli
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
@@ -15,5 +16,12 @@ class Cloud(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
import modal
|
||||
|
||||
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
accelerate: bool = True,
|
||||
local_dirs: Optional[dict[str, str]] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
def _train(
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
volumes=None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Build the base command
|
||||
if launcher == "accelerate":
|
||||
launcher_arg = "--launcher accelerate"
|
||||
elif launcher == "torchrun":
|
||||
launcher_arg = "--launcher torchrun"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
num_processes_args = ""
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
launcher_arg = "--launcher python"
|
||||
|
||||
# Build launcher args string
|
||||
launcher_args_str = ""
|
||||
if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
|
||||
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -197,14 +197,13 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
for key, value in kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
cfg[key] = value
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Generator, Union
|
||||
import fire
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Union
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
from dotenv import load_dotenv
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.cli.args import InferenceCliArgs
|
||||
@@ -268,5 +267,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -4,12 +4,9 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
build_command,
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
generate_config_files,
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
LAUNCHER_COMMAND_MAPPING = {
|
||||
"accelerate": ["accelerate", "launch"],
|
||||
"torchrun": ["torchrun"],
|
||||
}
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
patch_optimized_env()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -50,7 +55,7 @@ def cli():
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Preprocess datasets before training.
|
||||
|
||||
@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
@@ -88,126 +95,81 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
@click.pass_context
|
||||
def train(
|
||||
ctx: click.Context,
|
||||
config: str,
|
||||
accelerate: bool,
|
||||
cloud: Optional[str] = None,
|
||||
sweep: Optional[str] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
cloud: str | None = None,
|
||||
sweep: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
|
||||
cloud: Path to a cloud accelerator configuration file
|
||||
sweep: Path to YAML config for sweeping hyperparameters.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
if sweep:
|
||||
# load the sweep configuration yaml file
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
# Handle Ray launcher override
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
# generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
|
||||
def iter_configs():
|
||||
for perm in permutations:
|
||||
# open temp directory for temporary configurations
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with open(
|
||||
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
|
||||
) as fout:
|
||||
yaml.dump(perm, fout)
|
||||
yield str(Path(temp_dir) / "config.yaml")
|
||||
|
||||
else:
|
||||
|
||||
def iter_configs():
|
||||
yield config
|
||||
|
||||
for cfg_file in iter_configs():
|
||||
# handle errors from subprocess so we can continue rest of sweeps
|
||||
# Process each configuration
|
||||
for cfg_file in generate_config_files(config, sweep):
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=config,
|
||||
accelerate=True,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num_processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
raise exc
|
||||
finally:
|
||||
# Only delete temp files, not the original config
|
||||
if cfg_file != config:
|
||||
os.unlink(cfg_file)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU evaluation",
|
||||
)
|
||||
@add_options_from_dataclass(EvaluateCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
|
||||
"""
|
||||
Evaluate a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.evaluate"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -218,30 +180,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU inference",
|
||||
)
|
||||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
|
||||
"""
|
||||
Run inference with a trained model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
|
||||
gradio: Whether to use Gradio browser interface or command line for inference.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.inference"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
if gradio:
|
||||
@@ -254,33 +228,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
do_cli(config=config, gradio=gradio, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for weight merging",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for weight merging",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def merge_sharded_fsdp_weights(
|
||||
ctx: click.Context, config: str, launcher: str, **kwargs
|
||||
):
|
||||
"""
|
||||
Merge sharded FSDP model weights.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
"axolotl.cli.merge_sharded_fsdp_weights",
|
||||
]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -296,7 +279,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_lora(config: str, **kwargs) -> None:
|
||||
def merge_lora(config: str, **kwargs):
|
||||
"""
|
||||
Merge trained LoRA adapters into a base model.
|
||||
|
||||
@@ -313,7 +296,7 @@ def merge_lora(config: str, **kwargs) -> None:
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.option("--dest", help="Destination directory")
|
||||
def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
def fetch(directory: str, dest: Optional[str]):
|
||||
"""
|
||||
Fetch example configs or other resources.
|
||||
|
||||
@@ -351,7 +334,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
def delinearize_llama4(model: str, output: str) -> None:
|
||||
def delinearize_llama4(model: str, output: str):
|
||||
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
|
||||
|
||||
do_delinearize_llama4(model, output)
|
||||
@@ -365,5 +348,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
main()
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -17,7 +17,6 @@ from accelerate.utils import (
|
||||
WEIGHTS_NAME,
|
||||
is_torch_version,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
@@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,7 +9,6 @@ import fire
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from colorama import Fore
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
@@ -109,5 +108,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Union
|
||||
|
||||
import fire
|
||||
from accelerate import Accelerator
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Utility methods for axolotl CLI."""
|
||||
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
cmd.append(f"--{key}")
|
||||
else:
|
||||
cmd.extend([f"--{key}", str(value)])
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "new"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
23
src/axolotl/cli/utils/__init__.py
Normal file
23
src/axolotl/cli/utils/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Init for axolotl.cli.utils module."""
|
||||
|
||||
from .args import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from .fetch import fetch_from_github
|
||||
from .load import load_model_and_tokenizer
|
||||
from .sweeps import generate_sweep_configs
|
||||
from .train import build_command, generate_config_files, launch_training
|
||||
|
||||
__all__ = [
|
||||
"filter_none_kwargs",
|
||||
"add_options_from_dataclass",
|
||||
"add_options_from_config",
|
||||
"build_command",
|
||||
"generate_config_files",
|
||||
"generate_sweep_configs",
|
||||
"load_model_and_tokenizer",
|
||||
"launch_training",
|
||||
"fetch_from_github",
|
||||
]
|
||||
120
src/axolotl/cli/utils/args.py
Normal file
120
src/axolotl/cli/utils/args.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Utilities for axolotl CLI args."""
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = _strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
142
src/axolotl/cli/utils/fetch.py
Normal file
142
src/axolotl/cli/utils/fetch.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Utilities for axolotl fetch CLI command."""
|
||||
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "updated"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
_download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
52
src/axolotl/cli/utils/load.py
Normal file
52
src/axolotl/cli/utils/load.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Utilities for model, tokenizer, etc. loading."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
188
src/axolotl/cli/utils/train.py
Normal file
188
src/axolotl/cli/utils/train.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Utilities for axolotl train CLI command."""
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.utils.sweeps import generate_sweep_configs
|
||||
|
||||
|
||||
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
|
||||
"""
|
||||
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
|
||||
|
||||
Args:
|
||||
launcher_args: List of launcher arguments
|
||||
|
||||
Returns:
|
||||
Updated launcher args with defaults added if needed
|
||||
"""
|
||||
args = launcher_args.copy()
|
||||
|
||||
# Check if rdzv_endpoint is present
|
||||
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
|
||||
|
||||
if has_rdzv_endpoint:
|
||||
# Check if rdzv_backend is already provided
|
||||
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
|
||||
if not has_rdzv_backend:
|
||||
args.extend(["--rdzv_backend", "c10d"])
|
||||
|
||||
# Check if rdzv_id is already provided
|
||||
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
|
||||
if not has_rdzv_id:
|
||||
import uuid
|
||||
|
||||
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
cmd.append(f"--{key}={value}")
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
"""Generate list of configuration files to process."""
|
||||
if not sweep:
|
||||
yield config
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=".yaml",
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name
|
||||
|
||||
|
||||
def launch_training(
|
||||
cfg_file: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] | None,
|
||||
cloud: str | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Execute training with the given configuration."""
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
if cloud:
|
||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||
elif launcher:
|
||||
if launcher == "accelerate":
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
||||
elif launcher == "torchrun":
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
||||
elif launcher == "python":
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
|
||||
|
||||
def _launch_cloud_training(
|
||||
cloud: str,
|
||||
cfg_file: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Execute training via cloud launcher."""
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
cwd = os.getcwd() if launcher else None
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=cfg_file,
|
||||
launcher=launcher or "accelerate",
|
||||
launcher_args=launcher_args,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _launch_accelerate_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
) -> None:
|
||||
"""Execute training via accelerate launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
internal_launcher_args = []
|
||||
|
||||
# Extract launcher-specific arguments from kwargs (legacy support)
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port")
|
||||
internal_launcher_args.extend(["--main_process_port", str(main_process_port)])
|
||||
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes")
|
||||
internal_launcher_args.extend(["--num_processes", str(num_processes)])
|
||||
|
||||
# Combine internal args with user-provided launcher args
|
||||
all_launcher_args = internal_launcher_args + launcher_args
|
||||
|
||||
base_cmd = (
|
||||
["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
|
||||
)
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_torchrun_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
) -> None:
|
||||
"""Execute training via torchrun launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Add default RDZV arguments if rdzv_endpoint is set
|
||||
launcher_args = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||
"""Execute training via python launcher."""
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
Reference in New Issue
Block a user