[feature] sweeps (#2171)
This commit is contained in:
@@ -1,10 +1,17 @@
|
|||||||
"""Click CLI definitions for various axolotl commands."""
|
"""Click CLI definitions for various axolotl commands."""
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
import subprocess # nosec B404
|
import subprocess # nosec B404
|
||||||
|
import tempfile
|
||||||
|
from copy import deepcopy
|
||||||
|
from itertools import product
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
import yaml
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
@@ -20,6 +27,76 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
|
|||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sweep_configs(base_config, sweeps_config):
|
||||||
|
"""
|
||||||
|
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_config (dict): The original configuration dictionary
|
||||||
|
sweeps_config (dict): Dictionary where keys are parameters and values are either:
|
||||||
|
- lists of values to sweep independently
|
||||||
|
- or for paired values, a list of dicts under the '_' key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of all possible configuration dictionaries
|
||||||
|
|
||||||
|
Example:
|
||||||
|
sweeps_config = {
|
||||||
|
'learning_rate': [0.1, 0.01],
|
||||||
|
'_': [
|
||||||
|
{'load_in_8bit': True, 'adapter': 'lora'},
|
||||||
|
{'load_in_4bit': True, 'adapter': 'qlora'}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Separate paired values from regular sweeps
|
||||||
|
paired_values = sweeps_config.get("_", [])
|
||||||
|
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
|
||||||
|
|
||||||
|
# Process regular sweeps
|
||||||
|
param_names = list(regular_sweeps.keys())
|
||||||
|
param_values = list(regular_sweeps.values())
|
||||||
|
|
||||||
|
# Generate combinations for regular sweeps
|
||||||
|
regular_combinations = list(product(*param_values)) if param_values else [()]
|
||||||
|
|
||||||
|
# Combine regular sweeps with paired values
|
||||||
|
all_combinations = []
|
||||||
|
for reg_combo in regular_combinations:
|
||||||
|
if paired_values:
|
||||||
|
for paired_set in paired_values:
|
||||||
|
new_config = {}
|
||||||
|
# new_config = deepcopy(base_config)
|
||||||
|
# Combine regular parameters with paired parameters
|
||||||
|
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||||
|
for param_name, param_value in full_combo.items():
|
||||||
|
new_config[param_name] = param_value
|
||||||
|
print(new_config)
|
||||||
|
all_combinations.append(new_config)
|
||||||
|
else:
|
||||||
|
# If no paired values, just use regular combinations
|
||||||
|
# new_config = deepcopy(base_config)
|
||||||
|
new_config = {}
|
||||||
|
for param_name, param_value in zip(param_names, reg_combo):
|
||||||
|
new_config[param_name] = param_value
|
||||||
|
print(new_config)
|
||||||
|
all_combinations.append(new_config)
|
||||||
|
|
||||||
|
# randomize the order of trials
|
||||||
|
random.seed(42)
|
||||||
|
random.shuffle(all_combinations)
|
||||||
|
|
||||||
|
# Generate a new config for each combination
|
||||||
|
result_configs = []
|
||||||
|
for combination in all_combinations:
|
||||||
|
new_config = deepcopy(base_config)
|
||||||
|
for param_name, param_value in combination.items():
|
||||||
|
new_config[param_name] = param_value
|
||||||
|
result_configs.append(new_config)
|
||||||
|
|
||||||
|
return result_configs
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
@@ -60,10 +137,21 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
|||||||
help="Use accelerate launch for multi-GPU training",
|
help="Use accelerate launch for multi-GPU training",
|
||||||
)
|
)
|
||||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||||
|
@click.option(
|
||||||
|
"--sweep",
|
||||||
|
type=click.Path(exists=True, path_type=str),
|
||||||
|
help="YAML config for sweeping hyperparameters",
|
||||||
|
)
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
@filter_none_kwargs
|
@filter_none_kwargs
|
||||||
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
|
def train(
|
||||||
|
config: str,
|
||||||
|
accelerate: bool,
|
||||||
|
cloud: Optional[str] = None,
|
||||||
|
sweep: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Train or fine-tune a model.
|
Train or fine-tune a model.
|
||||||
|
|
||||||
@@ -71,6 +159,7 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
|
|||||||
config: Path to `axolotl` config YAML file.
|
config: Path to `axolotl` config YAML file.
|
||||||
accelerate: Whether to use `accelerate` launcher.
|
accelerate: Whether to use `accelerate` launcher.
|
||||||
cloud: Path to a cloud accelerator configuration file
|
cloud: Path to a cloud accelerator configuration file
|
||||||
|
sweep: Path to YAML config for sweeping hyperparameters.
|
||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
@@ -80,35 +169,66 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
|
|||||||
|
|
||||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||||
accelerate = False
|
accelerate = False
|
||||||
|
if sweep:
|
||||||
|
# load the sweep configuration yaml file
|
||||||
|
with open(sweep, "r", encoding="utf-8") as fin:
|
||||||
|
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||||
|
with open(config, "r", encoding="utf-8") as fin:
|
||||||
|
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||||
|
|
||||||
if accelerate:
|
# generate all possible configurations
|
||||||
if cloud:
|
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
|
||||||
else:
|
def iter_configs():
|
||||||
accelerate_args = []
|
for perm in permutations:
|
||||||
if "main_process_port" in kwargs:
|
# open temp directory for temporary configurations
|
||||||
main_process_port = kwargs.pop("main_process_port", None)
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
accelerate_args.append("--main_process_port")
|
with open(
|
||||||
accelerate_args.append(str(main_process_port))
|
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
|
||||||
if "num_processes" in kwargs:
|
) as fout:
|
||||||
num_processes = kwargs.pop("num_processes", None)
|
yaml.dump(perm, fout)
|
||||||
accelerate_args.append("--num-processes")
|
yield str(Path(temp_dir) / "config.yaml")
|
||||||
accelerate_args.append(str(num_processes))
|
|
||||||
|
|
||||||
base_cmd = ["accelerate", "launch"]
|
|
||||||
base_cmd.extend(accelerate_args)
|
|
||||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
|
||||||
if config:
|
|
||||||
base_cmd.append(config)
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
|
||||||
else:
|
else:
|
||||||
if cloud:
|
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
|
||||||
else:
|
|
||||||
from axolotl.cli.train import do_cli
|
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
def iter_configs():
|
||||||
|
yield config
|
||||||
|
|
||||||
|
for cfg_file in iter_configs():
|
||||||
|
# handle errors from subprocess so we can continue rest of sweeps
|
||||||
|
try:
|
||||||
|
if accelerate:
|
||||||
|
if cloud:
|
||||||
|
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||||
|
else:
|
||||||
|
accelerate_args = []
|
||||||
|
if "main_process_port" in kwargs:
|
||||||
|
main_process_port = kwargs.pop("main_process_port", None)
|
||||||
|
accelerate_args.append("--main_process_port")
|
||||||
|
accelerate_args.append(str(main_process_port))
|
||||||
|
if "num_processes" in kwargs:
|
||||||
|
num_processes = kwargs.pop("num_processes", None)
|
||||||
|
accelerate_args.append("--num-processes")
|
||||||
|
accelerate_args.append(str(num_processes))
|
||||||
|
|
||||||
|
base_cmd = ["accelerate", "launch"]
|
||||||
|
base_cmd.extend(accelerate_args)
|
||||||
|
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||||
|
if cfg_file:
|
||||||
|
base_cmd.append(cfg_file)
|
||||||
|
cmd = build_command(base_cmd, kwargs)
|
||||||
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
|
else:
|
||||||
|
if cloud:
|
||||||
|
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||||
|
else:
|
||||||
|
from axolotl.cli.train import do_cli
|
||||||
|
|
||||||
|
do_cli(config=cfg_file, **kwargs)
|
||||||
|
except subprocess.CalledProcessError as exc:
|
||||||
|
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||||
|
if not sweep:
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
68
tests/cli/test_cli_sweeps.py
Normal file
68
tests/cli/test_cli_sweeps.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
unit tests for generating sweep configurations
|
||||||
|
"""
|
||||||
|
from axolotl.cli.main import generate_sweep_configs
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_sweep_configs_no_pairs():
|
||||||
|
base_config = {
|
||||||
|
"learning_rate": 0.1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"sample_packing": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
|
||||||
|
|
||||||
|
generate_sweep_configs(base_config, sweeps_config)
|
||||||
|
|
||||||
|
assert len(generate_sweep_configs(base_config, sweeps_config)) == 6
|
||||||
|
|
||||||
|
cfg_1 = {
|
||||||
|
"learning_rate": 0.1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"sample_packing": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert any(
|
||||||
|
cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_sweep_configs_with_pairs():
|
||||||
|
base_config = {
|
||||||
|
"learning_rate": 0.1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"sample_packing": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
sweeps_config = {
|
||||||
|
"_": [
|
||||||
|
{
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 8,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"weight_decay": [0.0, 0.1],
|
||||||
|
}
|
||||||
|
|
||||||
|
generate_sweep_configs(base_config, sweeps_config)
|
||||||
|
|
||||||
|
assert len(generate_sweep_configs(base_config, sweeps_config)) == 8
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
|
||||||
|
for cfg in generate_sweep_configs(base_config, sweeps_config)
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user