diff --git a/src/axolotl/cli/utils/sweeps.py b/src/axolotl/cli/utils/sweeps.py index d21664964..bb1368cf6 100644 --- a/src/axolotl/cli/utils/sweeps.py +++ b/src/axolotl/cli/utils/sweeps.py @@ -3,11 +3,12 @@ import random from copy import deepcopy from itertools import product +from typing import Any def generate_sweep_configs( base_config: dict[str, list], sweeps_config: dict[str, list] -) -> list[dict[str, list]]: +) -> list[dict[str, Any]]: """ Recursively generates all possible configurations by applying sweeps to the base config. diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py index 31b0bcf58..b133d7271 100644 --- a/src/axolotl/cli/utils/train.py +++ b/src/axolotl/cli/utils/train.py @@ -4,6 +4,7 @@ import os import subprocess # nosec import sys import tempfile +from pathlib import Path from typing import Any, Iterator, Literal import yaml @@ -88,7 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, # Generate all possible configurations permutations = generate_sweep_configs(base_config, sweep_config) is_group = len(permutations) > 1 - for permutation in permutations: + base_output_dir = base_config.get("output_dir", "./model-out") + for idx, permutation in enumerate(permutations, start=1): + permutation_dir = Path(permutation.get("output_dir", base_output_dir)) + permutation_id = f"sweep{idx:04d}" + permutation["output_dir"] = str(permutation_dir / permutation_id) + # pylint: disable=consider-using-with temp_file = tempfile.NamedTemporaryFile( mode="w",