From 158330ab60086d9948472dc8e8530a330511a4bf Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 1 Feb 2025 21:11:18 -0500
Subject: [PATCH] [feature] sweeps (#2171)

---
 src/axolotl/cli/main.py      | 158 +++++++++++++++++++++++++++++------
 tests/cli/test_cli_sweeps.py |  68 +++++++++++++++
 2 files changed, 200 insertions(+), 26 deletions(-)
 create mode 100644 tests/cli/test_cli_sweeps.py

diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index e8551511e..d7aa1f6a7 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -1,10 +1,17 @@
 """Click CLI definitions for various axolotl commands."""
 # pylint: disable=redefined-outer-name
 
+import logging
+import random
 import subprocess  # nosec B404
+import tempfile
+from copy import deepcopy
+from itertools import product
+from pathlib import Path
 from typing import Optional
 
 import click
+import yaml
 
 import axolotl
 from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -20,6 +27,59 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
 
+def generate_sweep_configs(base_config, sweeps_config):
+    """
+    Generates all possible configurations by applying sweeps to the base config.
+
+    Args:
+        base_config (dict): The original configuration dictionary
+        sweeps_config (dict): Dictionary where keys are parameters and values are either:
+            - lists of values to sweep independently
+            - or for paired values, a list of dicts under the '_' key
+
+    Returns:
+        list: List of all possible configuration dictionaries, shuffled in a
+            fixed (seeded) order
+
+    Example:
+        sweeps_config = {
+            'learning_rate': [0.1, 0.01],
+            '_': [
+                {'load_in_8bit': True, 'adapter': 'lora'},
+                {'load_in_4bit': True, 'adapter': 'qlora'}
+            ]
+        }
+    """
+    # An empty sweep file is parsed by YAML as None; treat it as "no sweeps"
+    sweeps_config = sweeps_config or {}
+
+    # Separate paired values from regular sweeps; with no paired values a
+    # single empty set lets one code path handle both cases
+    paired_values = sweeps_config.get("_") or [{}]
+    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
+
+    # Generate combinations for regular sweeps
+    param_names = list(regular_sweeps)
+    regular_combinations = product(*regular_sweeps.values())
+
+    # Combine regular sweeps with paired values (paired values win on key clashes)
+    all_combinations = [
+        {**dict(zip(param_names, reg_combo)), **paired_set}
+        for reg_combo in regular_combinations
+        for paired_set in paired_values
+    ]
+
+    # randomize the order of trials; a private RNG keeps the process-wide
+    # `random` state untouched
+    random.Random(42).shuffle(all_combinations)
+
+    # Generate a new config for each combination
+    return [
+        {**deepcopy(base_config), **combination}
+        for combination in all_combinations
+    ]
+
+
 @click.group()
 @click.version_option(version=axolotl.__version__, prog_name="axolotl")
 def cli():
@@ -60,10 +120,21 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
     help="Use accelerate launch for multi-GPU training",
 )
 @click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
+@click.option(
+    "--sweep",
+    type=click.Path(exists=True, path_type=str),
+    help="YAML config for sweeping hyperparameters",
+)
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
+def train(
+    config: str,
+    accelerate: bool,
+    cloud: Optional[str] = None,
+    sweep: Optional[str] = None,
+    **kwargs,
+) -> None:
     """
     Train or fine-tune a model.
 
@@ -71,6 +142,7 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
         config: Path to `axolotl` config YAML file.
         accelerate: Whether to use `accelerate` launcher.
         cloud: Path to a cloud accelerator configuration file
+        sweep: Path to YAML config for sweeping hyperparameters.
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
@@ -80,35 +152,69 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
 
     if "use_ray" in kwargs and kwargs["use_ray"]:
         accelerate = False
+    if sweep:
+        # load the sweep configuration yaml file
+        with open(sweep, "r", encoding="utf-8") as fin:
+            sweep_config: dict[str, list] = yaml.safe_load(fin)
+        with open(config, "r", encoding="utf-8") as fin:
+            base_config: dict[str, list] = yaml.safe_load(fin)
 
-    if accelerate:
-        if cloud:
-            do_cli_train(cloud_config=cloud, config=config, accelerate=True)
-        else:
-            accelerate_args = []
-            if "main_process_port" in kwargs:
-                main_process_port = kwargs.pop("main_process_port", None)
-                accelerate_args.append("--main_process_port")
-                accelerate_args.append(str(main_process_port))
-            if "num_processes" in kwargs:
-                num_processes = kwargs.pop("num_processes", None)
-                accelerate_args.append("--num-processes")
-                accelerate_args.append(str(num_processes))
+        # generate all possible configurations
+        permutations = generate_sweep_configs(base_config, sweep_config)
+
+        def iter_configs():
+            for perm in permutations:
+                # open temp directory for temporary configurations
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    with open(
+                        Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
+                    ) as fout:
+                        yaml.dump(perm, fout)
+                    yield str(Path(temp_dir) / "config.yaml")
 
-            base_cmd = ["accelerate", "launch"]
-            base_cmd.extend(accelerate_args)
-            base_cmd.extend(["-m", "axolotl.cli.train"])
-            if config:
-                base_cmd.append(config)
-            cmd = build_command(base_cmd, kwargs)
-            subprocess.run(cmd, check=True)  # nosec B603
     else:
-        if cloud:
-            do_cli_train(cloud_config=cloud, config=config, accelerate=False)
-        else:
-            from axolotl.cli.train import do_cli
 
-            do_cli(config=config, **kwargs)
+        def iter_configs():
+            yield config
+
+    for cfg_file in iter_configs():
+        # handle errors from subprocess so we can continue rest of sweeps
+        try:
+            if accelerate:
+                if cloud:
+                    do_cli_train(cloud_config=cloud, config=cfg_file, accelerate=True)
+                else:
+                    # work on a copy so the launcher options are still present
+                    # in `kwargs` for the next sweep iteration
+                    run_kwargs = dict(kwargs)
+                    accelerate_args = []
+                    if "main_process_port" in run_kwargs:
+                        main_process_port = run_kwargs.pop("main_process_port", None)
+                        accelerate_args.append("--main_process_port")
+                        accelerate_args.append(str(main_process_port))
+                    if "num_processes" in run_kwargs:
+                        num_processes = run_kwargs.pop("num_processes", None)
+                        accelerate_args.append("--num-processes")
+                        accelerate_args.append(str(num_processes))
+
+                    base_cmd = ["accelerate", "launch"]
+                    base_cmd.extend(accelerate_args)
+                    base_cmd.extend(["-m", "axolotl.cli.train"])
+                    if cfg_file:
+                        base_cmd.append(cfg_file)
+                    cmd = build_command(base_cmd, run_kwargs)
+                    subprocess.run(cmd, check=True)  # nosec B603
+            else:
+                if cloud:
+                    do_cli_train(cloud_config=cloud, config=cfg_file, accelerate=False)
+                else:
+                    from axolotl.cli.train import do_cli
+
+                    do_cli(config=cfg_file, **kwargs)
+        except subprocess.CalledProcessError as exc:
+            logging.error("Failed to train/fine-tune config '%s': %s", cfg_file, exc)
+            if not sweep:
+                raise
 
 
 @cli.command()
diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py
new file mode 100644
index 000000000..61c886e80
--- /dev/null
+++ b/tests/cli/test_cli_sweeps.py
@@ -0,0 +1,68 @@
+"""
+unit tests for generating sweep configurations
+"""
+from axolotl.cli.main import generate_sweep_configs
+
+
+def test_generate_sweep_configs_no_pairs():
+    base_config = {
+        "learning_rate": 0.1,
+        "micro_batch_size": 1,
+        "sample_packing": True,
+    }
+
+    sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
+
+    generate_sweep_configs(base_config, sweeps_config)
+
+    assert len(generate_sweep_configs(base_config, sweeps_config)) == 6
+
+    cfg_1 = {
+        "learning_rate": 0.1,
+        "micro_batch_size": 2,
+        "weight_decay": 0.0,
+        "sample_packing": True,
+    }
+
+    assert any(
+        cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)
+    )
+
+
+def test_generate_sweep_configs_with_pairs():
+    base_config = {
+        "learning_rate": 0.1,
+        "micro_batch_size": 1,
+        "sample_packing": True,
+    }
+
+    sweeps_config = {
+        "_": [
+            {
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 8,
+            },
+            {
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 4,
+            },
+            {
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+            },
+            {
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+            },
+        ],
+        "weight_decay": [0.0, 0.1],
+    }
+
+    generate_sweep_configs(base_config, sweeps_config)
+
+    assert len(generate_sweep_configs(base_config, sweeps_config)) == 8
+
+    assert all(
+        cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
+        for cfg in generate_sweep_configs(base_config, sweeps_config)
+    )