[feature] sweeps (#2171)

This commit is contained in:
Wing Lian
2025-02-01 21:11:18 -05:00
committed by GitHub
parent 80e1468b8d
commit 158330ab60
2 changed files with 214 additions and 26 deletions

View File

@@ -1,10 +1,17 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging
import random
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional
import click
import yaml
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -20,6 +27,76 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
def generate_sweep_configs(base_config, sweeps_config):
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -60,10 +137,21 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
"--sweep",
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
def train(
config: str,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
**kwargs,
) -> None:
"""
Train or fine-tune a model.
@@ -71,6 +159,7 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
@@ -80,35 +169,66 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
@cli.command()

View File

@@ -0,0 +1,68 @@
"""
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs
def test_generate_sweep_configs_no_pairs():
base_config = {
"learning_rate": 0.1,
"micro_batch_size": 1,
"sample_packing": True,
}
sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
generate_sweep_configs(base_config, sweeps_config)
assert len(generate_sweep_configs(base_config, sweeps_config)) == 6
cfg_1 = {
"learning_rate": 0.1,
"micro_batch_size": 2,
"weight_decay": 0.0,
"sample_packing": True,
}
assert any(
cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)
)
def test_generate_sweep_configs_with_pairs():
base_config = {
"learning_rate": 0.1,
"micro_batch_size": 1,
"sample_packing": True,
}
sweeps_config = {
"_": [
{
"micro_batch_size": 1,
"gradient_accumulation_steps": 8,
},
{
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
},
{
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
},
{
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
},
],
"weight_decay": [0.0, 0.1],
}
generate_sweep_configs(base_config, sweeps_config)
assert len(generate_sweep_configs(base_config, sweeps_config)) == 8
assert all(
cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
for cfg in generate_sweep_configs(base_config, sweeps_config)
)