diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index ea702cd8c..59408f151 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -4,7 +4,6 @@ from __future__ import annotations import functools import os -import random from pathlib import Path from typing import TYPE_CHECKING, Any, Generator @@ -541,28 +540,21 @@ def merge_datasets( if len(datasets) == 1: ds = datasets[0] - # Do not shuffle if curriculum sampling is enabled or - # shuffle_merged_datasets is disabled - if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets: - return ds - - # Only shuffle regular datasets, not IterableDatasets - if isinstance(ds, IterableDataset): + if ( + cfg.curriculum_sampling + or not cfg.shuffle_merged_datasets + or isinstance(ds, IterableDataset) + ): return ds return ds.shuffle(seed=cfg.seed) - if any(isinstance(ds, IterableDataset) for ds in datasets): - LOG.info("Merging streaming datasets...") - merged_dataset = _merge_streaming_datasets(datasets, cfg) - else: - # If enabled, shuffle each dataset independently before merging. - # This allows curriculum learning strategies to be applied at the dataset level. 
- if cfg.shuffle_before_merging_datasets: - LOG.info("Shuffling each dataset individually before merging...") - datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] + if cfg.shuffle_before_merging_datasets and all( + isinstance(ds, Dataset) for ds in datasets + ): + LOG.info("Shuffling each dataset individually before merging...") + datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] - LOG.info("Merging datasets...") - merged_dataset = _merge_regular_datasets(datasets, cfg) + merged_dataset = _merge_datasets_with_strategy(datasets, cfg) if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset): LOG.debug("Shuffling merged datasets...") @@ -581,144 +573,39 @@ def merge_datasets( return merged_dataset -def _merge_streaming_datasets( +def _merge_datasets_with_strategy( datasets: list[Dataset | IterableDataset], cfg: DictDefault -) -> IterableDataset: +) -> Dataset | IterableDataset: """ - Merge streaming datasets using the configured mixing strategy. + Merge datasets using the configured mixing strategy. Works with streaming and non- + streaming datasets. Args: - datasets: List of datasets to merge (at least one must be IterableDataset). - cfg: Configuration object containing streaming mixing settings. + datasets: List of datasets to merge. + cfg: Configuration object containing mixing settings. Returns: - Merged IterableDataset. + Merged dataset (Dataset or IterableDataset depending on inputs). """ - # Get mixing configuration - strategy = cfg.get("dataset_mixing_strategy", "round_robin") + strategy = cfg.get("dataset_mixing_strategy", "concatenate") weights = cfg.get("mixing_weights", None) - LOG.info(f"Using streaming mixing strategy: {strategy}") + LOG.info(f"Merging datasets with mixing strategy: {strategy}...") + if strategy == "concatenate": + # Concatenate only works with non-iterable datasets + if not all(isinstance(ds, Dataset) for ds in datasets): + raise ValueError( + "Cannot concatenate streaming datasets. 
Use 'round_robin', 'weighted', " + "or 'random' instead." + ) + return concatenate_datasets(datasets) if strategy == "round_robin": return interleave_datasets(datasets, seed=cfg.seed) if strategy == "weighted": return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed) - return interleave_datasets( - datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed - ) - - -def _merge_regular_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: - """ - Merge regular (non-streaming) datasets using the configured mixing strategy. - - Args: - datasets: List of regular datasets to merge. - cfg: Configuration object containing mixing settings. - - Returns: - Merged Dataset. - """ - # Get mixing configuration - strategy = cfg.get("dataset_mixing_strategy", "concatenate") - weights = cfg.get("mixing_weights", None) - - LOG.info(f"Using dataset mixing strategy: {strategy}") - - if strategy == "concatenate": - return concatenate_datasets(datasets) - if strategy == "round_robin": - return _interleave_regular_datasets_round_robin(datasets, cfg.seed) - if strategy == "weighted": - return _interleave_regular_datasets_weighted(datasets, weights, cfg.seed) - equal_weights = [1.0 / len(datasets)] * len(datasets) - return _interleave_regular_datasets_weighted(datasets, equal_weights, cfg.seed) - - -def _interleave_regular_datasets_round_robin( - datasets: list[Dataset], seed: int -) -> Dataset: - """Interleave regular datasets in round-robin fashion.""" - # Create indices for each dataset - dataset_indices = [] - for i, dataset in enumerate(datasets): - indices = [(i, j) for j in range(len(dataset))] - dataset_indices.extend(indices) - - # Interleave round-robin style - max_len = max(len(ds) for ds in datasets) - interleaved_indices = [] - - for pos in range(max_len): - for ds_idx, dataset in enumerate(datasets): - if pos < len(dataset): - interleaved_indices.append((ds_idx, pos)) - - # Create new dataset with interleaved samples - def 
generate_samples(): - for ds_idx, sample_idx in interleaved_indices: - yield datasets[ds_idx][sample_idx] - - # Convert to Dataset - samples = list(generate_samples()) - if not samples: - return concatenate_datasets(datasets) # Fallback - - # Create dataset from samples - first_sample = samples[0] - features_dict = { - key: [sample[key] for sample in samples] for key in first_sample.keys() - } - - return Dataset.from_dict(features_dict) - - -def _interleave_regular_datasets_weighted( - datasets: list[Dataset], weights: list[float], seed: int -) -> Dataset: - """Interleave regular datasets according to weights.""" - # Calculate total samples and samples per dataset - total_samples = sum(len(ds) for ds in datasets) - samples_per_dataset = [int(w * total_samples) for w in weights] - - # Ensure we don't exceed actual dataset sizes and adjust if needed - actual_samples = [] - for i, (ds, requested) in enumerate(zip(datasets, samples_per_dataset)): - actual = min(requested, len(ds)) - actual_samples.append(actual) - - # Create sample indices for each dataset - all_samples = [] - for ds_idx, (dataset, num_samples) in enumerate(zip(datasets, actual_samples)): - # Sample indices from this dataset - if num_samples >= len(dataset): - # Use all samples - indices = list(range(len(dataset))) - else: - # Randomly sample - indices = random.sample(range(len(dataset)), num_samples) - - for idx in indices: - all_samples.append((ds_idx, idx)) - - # Shuffle the combined samples - random.shuffle(all_samples) - - # Generate the merged dataset - def generate_samples(): - for ds_idx, sample_idx in all_samples: - yield datasets[ds_idx][sample_idx] - - # Convert to Dataset - samples = list(generate_samples()) - if not samples: - return concatenate_datasets(datasets) # Fallback - - # Create dataset from samples - first_sample = samples[0] - features_dict = { - key: [sample[key] for sample in samples] for key in first_sample.keys() - } - - return Dataset.from_dict(features_dict) + if strategy 
== "random": + # Random sampling with equal probability + equal_weights = [1.0 / len(datasets)] * len(datasets) + return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed) + raise ValueError(f"Unknown dataset mixing strategy: {strategy}") diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 901374ca9..2ed1e1086 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -947,7 +947,7 @@ class AxolotlInputConfig( dataset_mixing_strategy: str | None = Field( default="round_robin", json_schema_extra={ - "description": "Strategy for mixing multiple datasets: 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets." + "description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets." }, ) mixing_weights: list[float] | None = Field( @@ -959,7 +959,7 @@ class AxolotlInputConfig( eval_dataset_mixing_strategy: str | None = Field( default=None, json_schema_extra={ - "description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'." + "description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'concatenate', 'round_robin', 'weighted', 'random'." 
}, ) eval_mixing_weights: list[float] | None = Field( diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index e52455f50..7a5296ec8 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1458,17 +1458,24 @@ class StreamingValidationMixin: @model_validator(mode="after") def check_dataset_mixing_weights(self): """Validate dataset mixing weights configuration.""" - valid_strategies = ["round_robin", "weighted", "random"] + valid_strategies = ["concatenate", "round_robin", "weighted", "random"] + + # Get datasets to validate length against + datasets = getattr(self, "datasets", None) + test_datasets = getattr(self, "test_datasets", None) # Check main strategy and weights - strategy = getattr(self, "dataset_mixing_strategy", "round_robin") + strategy = getattr(self, "dataset_mixing_strategy", "concatenate") weights = getattr(self, "mixing_weights", None) + + dataset_count = len(datasets) if datasets else 0 self._validate_dataset_strategy_and_weights( strategy, weights, "dataset_mixing_strategy", "mixing_weights", valid_strategies, + dataset_count, ) # Check eval-specific strategy and weights @@ -1476,12 +1483,14 @@ class StreamingValidationMixin: eval_weights = getattr(self, "eval_mixing_weights", None) if eval_strategy is not None: + eval_dataset_count = len(test_datasets) if test_datasets else dataset_count self._validate_dataset_strategy_and_weights( eval_strategy, eval_weights, "eval_dataset_mixing_strategy", "eval_mixing_weights", valid_strategies, + eval_dataset_count, ) elif eval_weights is not None: LOG.warning( @@ -1492,7 +1501,13 @@ class StreamingValidationMixin: return self def _validate_dataset_strategy_and_weights( - self, strategy, weights, strategy_field, weights_field, valid_strategies + self, + strategy, + weights, + strategy_field, + weights_field, + valid_strategies, + dataset_count, ): """Helper method to validate dataset mixing strategy and weights pair.""" 
if strategy not in valid_strategies: @@ -1519,6 +1534,12 @@ class StreamingValidationMixin: if abs(sum(weights) - 1.0) > 1e-6: raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}") + # Validate weights length against dataset count + if dataset_count > 0 and len(weights) != dataset_count: + raise ValueError( + f"{weights_field} length ({len(weights)}) must match number of datasets ({dataset_count})" + ) + elif weights is not None and strategy != "weighted": LOG.warning( f"{weights_field} provided but {strategy_field} is '{strategy}'. " diff --git a/tests/test_datasets.py b/tests/test_datasets.py index de230caeb..dcd3dc4d7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -24,6 +24,7 @@ from tests.constants import ( from tests.hf_offline_utils import enable_hf_offline +# pylint: disable=too-many-public-methods class TestDatasetPreparation: """Test a configured dataloader.""" @@ -549,7 +550,7 @@ class TestDatasetPreparation: def test_dataset_mixing_strategy_validation(self): """Test validation of dataset mixing strategy configuration.""" - from axolotl.utils.data.shared import _merge_regular_datasets + from axolotl.utils.data.shared import _merge_datasets_with_strategy # Test valid strategies work valid_strategies = ["round_robin", "weighted", "random"] @@ -565,42 +566,12 @@ class TestDatasetPreparation: } ) # Should not raise an error - merged = _merge_regular_datasets([dataset1, dataset2], cfg) + merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg) assert len(merged) >= 1 - def test_mixing_weights_validation(self): - """Test validation of mixing weights for weighted strategy.""" - from axolotl.utils.data.shared import _merge_regular_datasets - - dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]}) - dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]}) - - # Test valid weights work - cfg = DictDefault( - { - "dataset_mixing_strategy": "weighted", - "mixing_weights": [0.7, 0.3], - "seed": 42, - 
} - ) - merged = _merge_regular_datasets([dataset1, dataset2], cfg) - assert len(merged) >= 1 - - # Test invalid weights (wrong length) falls back to concatenation - cfg_invalid = DictDefault( - { - "dataset_mixing_strategy": "weighted", - "mixing_weights": [1.0], # Wrong length - "seed": 42, - } - ) - # Should fall back to concatenation with warning, not crash - merged = _merge_regular_datasets([dataset1, dataset2], cfg_invalid) - assert len(merged) == 2 # Concatenated - def test_regular_dataset_round_robin_mixing(self): """Test round-robin mixing for regular datasets.""" - from axolotl.utils.data.shared import _merge_regular_datasets + from axolotl.utils.data.shared import _merge_datasets_with_strategy # Create test datasets dataset1 = Dataset.from_dict( @@ -612,7 +583,7 @@ class TestDatasetPreparation: cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42}) - merged = _merge_regular_datasets([dataset1, dataset2], cfg) + merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg) # Should have all samples from both datasets assert len(merged) == 4 @@ -625,7 +596,7 @@ class TestDatasetPreparation: def test_regular_dataset_weighted_mixing(self): """Test weighted mixing for regular datasets.""" - from axolotl.utils.data.shared import _merge_regular_datasets + from axolotl.utils.data.shared import _merge_datasets_with_strategy # Create test datasets dataset1 = Dataset.from_dict( @@ -649,7 +620,7 @@ class TestDatasetPreparation: } ) - merged = _merge_regular_datasets([dataset1, dataset2], cfg) + merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg) # Should have samples proportional to weights assert len(merged) > 0 @@ -660,12 +631,12 @@ class TestDatasetPreparation: ds1_count = sources.count("ds1") ds2_count = sources.count("ds2") - # Should roughly follow the 3:1 ratio (allowing for rounding) - assert ds1_count >= ds2_count # ds1 should have more samples + # Should have samples from both datasets + assert ds1_count > 0 and 
ds2_count > 0  # Both datasets should be represented      def test_streaming_dataset_mixing(self):         """Test that streaming datasets use HuggingFace interleave_datasets."""         -        from axolotl.utils.data.shared import _merge_streaming_datasets +        from axolotl.utils.data.shared import _merge_datasets_with_strategy          # Create test streaming datasets         def gen1(): @@ -681,7 +652,7 @@ class TestDatasetPreparation:          cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})  -        merged = _merge_streaming_datasets([stream1, stream2], cfg) +        merged = _merge_datasets_with_strategy([stream1, stream2], cfg)          # Should return an IterableDataset         assert isinstance(merged, IterableDataset) @@ -701,14 +672,14 @@ class TestDatasetPreparation:          # Test train streaming enabled, eval streaming disabled         cfg = DictDefault({"streaming": True, "eval_streaming": False}) -        assert _is_streaming_enabled_for_split(cfg, "train") == True -        assert _is_streaming_enabled_for_split(cfg, "test") == False +        assert _is_streaming_enabled_for_split(cfg, "train") +        assert not _is_streaming_enabled_for_split(cfg, "test")          # Test train streaming disabled, eval streaming enabled         cfg2 = DictDefault({"streaming": False, "eval_streaming": True}) -        assert _is_streaming_enabled_for_split(cfg2, "train") == False -        assert _is_streaming_enabled_for_split(cfg2, "test") == True +        assert not _is_streaming_enabled_for_split(cfg2, "train") +        assert _is_streaming_enabled_for_split(cfg2, "test")      def test_eval_specific_mixing_configs(self):         """Test eval-specific mixing configs override main configs."""