From 050210e637a7ca2fdb65491eced13bd4d1ce5d10 Mon Sep 17 00:00:00 2001 From: goggle Date: Wed, 20 Aug 2025 09:25:20 +0900 Subject: [PATCH] fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080) * refactor: improve output_dir handling in generate_config_files * fix typo * cli: harden sweep output_dir handling with base fallback - Ensure sweep permutations always resolve a valid output_dir - Default to ./model-out if neither permutation nor base config sets output_dir - Append sweepXXXX suffix consistently for each permutation - Prevent Path(None) TypeError and improve robustness of sweep config generation * fix typo * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/cli/utils/sweeps.py | 3 ++- src/axolotl/cli/utils/train.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/utils/sweeps.py b/src/axolotl/cli/utils/sweeps.py index d21664964..bb1368cf6 100644 --- a/src/axolotl/cli/utils/sweeps.py +++ b/src/axolotl/cli/utils/sweeps.py @@ -3,11 +3,12 @@ import random from copy import deepcopy from itertools import product +from typing import Any def generate_sweep_configs( base_config: dict[str, list], sweeps_config: dict[str, list] -) -> list[dict[str, list]]: +) -> list[dict[str, Any]]: """ Recursively generates all possible configurations by applying sweeps to the base config. diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py index 31b0bcf58..b133d7271 100644 --- a/src/axolotl/cli/utils/train.py +++ b/src/axolotl/cli/utils/train.py @@ -4,6 +4,7 @@ import os import subprocess # nosec import sys import tempfile +from pathlib import Path from typing import Any, Iterator, Literal import yaml @@ -88,7 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, # Generate all possible configurations permutations = generate_sweep_configs(base_config, sweep_config) is_group = len(permutations) > 1 - for permutation in permutations: + base_output_dir = base_config.get("output_dir", "./model-out") + for idx, permutation in enumerate(permutations, start=1): + permutation_dir = Path(permutation.get("output_dir", base_output_dir)) + permutation_id = f"sweep{idx:04d}" + permutation["output_dir"] = str(permutation_dir / permutation_id) + # pylint: disable=consider-using-with temp_file = tempfile.NamedTemporaryFile( mode="w",