[feature] sweeps (#2171)
src/axolotl/cli/main.py
@@ -1,10 +1,17 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name

import logging
import random
import subprocess  # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional

import click
import yaml

import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -20,6 +27,76 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


def generate_sweep_configs(base_config, sweeps_config):
    """
    Generates all possible configurations by applying sweeps to the base config.

    Args:
        base_config (dict): The original configuration dictionary.
        sweeps_config (dict): Dictionary where keys are parameters and values are either:
            - lists of values to sweep independently, or
            - for paired values, a list of dicts under the '_' key.

    Returns:
        list: List of all possible configuration dictionaries.

    Example:
        sweeps_config = {
            'learning_rate': [0.1, 0.01],
            '_': [
                {'load_in_8bit': True, 'adapter': 'lora'},
                {'load_in_4bit': True, 'adapter': 'qlora'}
            ]
        }
    """
    # Separate paired values from regular sweeps
    paired_values = sweeps_config.get("_", [])
    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}

    # Process regular sweeps
    param_names = list(regular_sweeps.keys())
    param_values = list(regular_sweeps.values())

    # Generate combinations for regular sweeps
    regular_combinations = list(product(*param_values)) if param_values else [()]

    # Combine regular sweeps with paired values
    all_combinations = []
    for reg_combo in regular_combinations:
        if paired_values:
            for paired_set in paired_values:
                # Combine regular parameters with paired parameters
                full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
                all_combinations.append(full_combo)
        else:
            # If no paired values, just use the regular combination
            all_combinations.append(dict(zip(param_names, reg_combo)))

    # Randomize the order of trials; the fixed seed keeps runs reproducible
    random.seed(42)
    random.shuffle(all_combinations)

    # Apply each combination on top of a fresh copy of the base config
    result_configs = []
    for combination in all_combinations:
        new_config = deepcopy(base_config)
        for param_name, param_value in combination.items():
            new_config[param_name] = param_value
        result_configs.append(new_config)

    return result_configs
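
# Illustrative sketch (assumed usage, not lines from this commit): two swept
# learning rates x two paired sets under "_" expand to four configs, each a
# deep copy of the base config. "micro_batch_size" is a hypothetical base key.
example_base = {"micro_batch_size": 1}
example_sweeps = {
    "learning_rate": [0.1, 0.01],
    "_": [
        {"load_in_8bit": True, "adapter": "lora"},
        {"load_in_4bit": True, "adapter": "qlora"},
    ],
}
example_configs = generate_sweep_configs(example_base, example_sweeps)
assert len(example_configs) == 4
assert all(cfg["micro_batch_size"] == 1 for cfg in example_configs)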


@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():

@@ -60,10 +137,21 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
    help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
    "--sweep",
    type=click.Path(exists=True, path_type=str),
    help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
def train(
    config: str,
    accelerate: bool,
    cloud: Optional[str] = None,
    sweep: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Train or fine-tune a model.

@@ -71,6 +159,7 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
        config: Path to `axolotl` config YAML file.
        accelerate: Whether to use `accelerate` launcher.
        cloud: Path to a cloud accelerator configuration file.
        sweep: Path to YAML config for sweeping hyperparameters.
        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
            config options.
    """

@@ -80,35 +169,66 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
    if "use_ray" in kwargs and kwargs["use_ray"]:
        accelerate = False
    if sweep:
        # load the sweep configuration yaml file
        with open(sweep, "r", encoding="utf-8") as fin:
            sweep_config: dict[str, list] = yaml.safe_load(fin)
        with open(config, "r", encoding="utf-8") as fin:
            base_config: dict[str, list] = yaml.safe_load(fin)

    if accelerate:
        if cloud:
            do_cli_train(cloud_config=cloud, config=config, accelerate=True)
        else:
            accelerate_args = []
            if "main_process_port" in kwargs:
                main_process_port = kwargs.pop("main_process_port", None)
                accelerate_args.append("--main_process_port")
                accelerate_args.append(str(main_process_port))
            if "num_processes" in kwargs:
                num_processes = kwargs.pop("num_processes", None)
                accelerate_args.append("--num-processes")
                accelerate_args.append(str(num_processes))
        # generate all possible configurations
        permutations = generate_sweep_configs(base_config, sweep_config)

        def iter_configs():
            for perm in permutations:
                # open temp directory for temporary configurations
                with tempfile.TemporaryDirectory() as temp_dir:
                    with open(
                        Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
                    ) as fout:
                        yaml.dump(perm, fout)
                    yield str(Path(temp_dir) / "config.yaml")

            base_cmd = ["accelerate", "launch"]
            base_cmd.extend(accelerate_args)
            base_cmd.extend(["-m", "axolotl.cli.train"])
            if config:
                base_cmd.append(config)
            cmd = build_command(base_cmd, kwargs)
            subprocess.run(cmd, check=True)  # nosec B603
    else:
        if cloud:
            do_cli_train(cloud_config=cloud, config=config, accelerate=False)
        else:
            from axolotl.cli.train import do_cli

            do_cli(config=config, **kwargs)

        def iter_configs():
            yield config

    for cfg_file in iter_configs():
        # handle errors from the subprocess so we can continue the rest of the sweep
        try:
            if accelerate:
                if cloud:
                    do_cli_train(cloud_config=cloud, config=config, accelerate=True)
                else:
                    accelerate_args = []
                    if "main_process_port" in kwargs:
                        main_process_port = kwargs.pop("main_process_port", None)
                        accelerate_args.append("--main_process_port")
                        accelerate_args.append(str(main_process_port))
                    if "num_processes" in kwargs:
                        num_processes = kwargs.pop("num_processes", None)
                        accelerate_args.append("--num-processes")
                        accelerate_args.append(str(num_processes))

                    base_cmd = ["accelerate", "launch"]
                    base_cmd.extend(accelerate_args)
                    base_cmd.extend(["-m", "axolotl.cli.train"])
                    if cfg_file:
                        base_cmd.append(cfg_file)
                    cmd = build_command(base_cmd, kwargs)
                    subprocess.run(cmd, check=True)  # nosec B603
            else:
                if cloud:
                    do_cli_train(cloud_config=cloud, config=config, accelerate=False)
                else:
                    from axolotl.cli.train import do_cli

                    do_cli(config=cfg_file, **kwargs)
        except subprocess.CalledProcessError as exc:
            logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
            if not sweep:
                raise exc


@cli.command()
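
# Condensed sketch of the run loop above (assumed, simplified; reuses the
# imports added at the top of this file): each sweep permutation is written to
# its own temporary YAML and launched as a separate process, so a failed trial
# is logged and the remaining trials still run. The bare "axolotl train" call
# stands in for the commit's accelerate/do_cli dispatch.
def run_sweep(permutations: list[dict]) -> None:
    for perm in permutations:
        with tempfile.TemporaryDirectory() as temp_dir:
            cfg_file = Path(temp_dir) / "config.yaml"
            cfg_file.write_text(yaml.dump(perm), encoding="utf-8")
            try:
                subprocess.run(["axolotl", "train", str(cfg_file)], check=True)  # nosec B603
            except subprocess.CalledProcessError as exc:
                logging.error(f"Failed to train config '{cfg_file}': {exc}")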

tests/cli/test_cli_sweeps.py (new file)
@@ -0,0 +1,68 @@
"""
Unit tests for generating sweep configurations.
"""
from axolotl.cli.main import generate_sweep_configs


def test_generate_sweep_configs_no_pairs():
    base_config = {
        "learning_rate": 0.1,
        "micro_batch_size": 1,
        "sample_packing": True,
    }

    sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}

    # 3 micro_batch_size values x 2 weight_decay values = 6 configs
    assert len(generate_sweep_configs(base_config, sweeps_config)) == 6

    cfg_1 = {
        "learning_rate": 0.1,
        "micro_batch_size": 2,
        "weight_decay": 0.0,
        "sample_packing": True,
    }

    assert any(
        cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)
    )


def test_generate_sweep_configs_with_pairs():
    base_config = {
        "learning_rate": 0.1,
        "micro_batch_size": 1,
        "sample_packing": True,
    }

    sweeps_config = {
        "_": [
            {
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 8,
            },
            {
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 4,
            },
            {
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 2,
            },
            {
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
            },
        ],
        "weight_decay": [0.0, 0.1],
    }

    # 4 paired sets x 2 weight_decay values = 8 configs
    assert len(generate_sweep_configs(base_config, sweeps_config)) == 8

    # every paired set keeps the effective batch size (micro_batch_size x
    # gradient_accumulation_steps) fixed at 8
    assert all(
        cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
        for cfg in generate_sweep_configs(base_config, sweeps_config)
    )
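
# Illustrative extra check (assumed, not part of this commit's tests): the
# fixed random.seed(42) inside generate_sweep_configs makes the shuffled
# trial order repeatable across calls.
def test_generate_sweep_configs_deterministic_order():
    base_config = {"learning_rate": 0.1}
    sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
    assert generate_sweep_configs(
        base_config, sweeps_config
    ) == generate_sweep_configs(base_config, sweeps_config)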