fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files * fix typo * cli: harden sweep output_dir handling with base fallback - Ensure sweep permutations always resolve a valid output_dir - Default to ./model-out if neither permutation nor base config sets output_dir - Append sweepXXXX suffix consistently for each permutation - Prevent Path(None) TypeError and improve robustness of sweep config generation * fix typo * chore: lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -3,11 +3,12 @@
|
|||||||
import random
|
import random
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def generate_sweep_configs(
|
def generate_sweep_configs(
|
||||||
base_config: dict[str, list], sweeps_config: dict[str, list]
|
base_config: dict[str, list], sweeps_config: dict[str, list]
|
||||||
) -> list[dict[str, list]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Iterator, Literal
|
from typing import Any, Iterator, Literal
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@@ -88,7 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
|
|||||||
# Generate all possible configurations
|
# Generate all possible configurations
|
||||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||||
is_group = len(permutations) > 1
|
is_group = len(permutations) > 1
|
||||||
for permutation in permutations:
|
base_output_dir = base_config.get("output_dir", "./model-out")
|
||||||
|
for idx, permutation in enumerate(permutations, start=1):
|
||||||
|
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
|
||||||
|
permutation_id = f"sweep{idx:04d}"
|
||||||
|
permutation["output_dir"] = str(permutation_dir / permutation_id)
|
||||||
|
|
||||||
# pylint: disable=consider-using-with
|
# pylint: disable=consider-using-with
|
||||||
temp_file = tempfile.NamedTemporaryFile(
|
temp_file = tempfile.NamedTemporaryFile(
|
||||||
mode="w",
|
mode="w",
|
||||||
|
|||||||
Reference in New Issue
Block a user