From aa5a497a2ca6623ec09b22938e1d94e227c7cbac Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 20 Aug 2025 13:46:29 +0000 Subject: [PATCH] nits --- src/axolotl/utils/data/sft.py | 18 ++- src/axolotl/utils/data/shared.py | 126 +++++++++++++++- src/axolotl/utils/schemas/config.py | 14 +- src/axolotl/utils/schemas/validation.py | 32 ++-- tests/test_datasets.py | 186 ++++++++++++++++++++++++ 5 files changed, 340 insertions(+), 36 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 8fb4ea63d..aeb55a73b 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -75,13 +75,13 @@ def _get_streaming_config_for_split( # Override with eval-specific configs if they exist streaming_cfg = DictDefault(cfg) - eval_strategy = cfg.get("eval_streaming_dataset_mixing_strategy") - eval_weights = cfg.get("eval_streaming_mixing_weights") + eval_strategy = cfg.get("eval_dataset_mixing_strategy") + eval_weights = cfg.get("eval_mixing_weights") if eval_strategy is not None: - streaming_cfg["streaming_dataset_mixing_strategy"] = eval_strategy + streaming_cfg["dataset_mixing_strategy"] = eval_strategy if eval_weights is not None: - streaming_cfg["streaming_mixing_weights"] = eval_weights + streaming_cfg["mixing_weights"] = eval_weights return streaming_cfg @@ -392,10 +392,12 @@ def _load_raw_datasets( if cfg.sample_packing: dataset, _ = process_datasets_for_packing(cfg, dataset, None) - dataset_hash = generate_dataset_hash_from_config( - cfg, datasets_configs, tokenizer.name_or_path - ) - save_preprocessed_dataset(cfg, dataset, dataset_hash, split) + # Only save regular datasets to disk, not streaming datasets + if not isinstance(dataset, IterableDataset): + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + save_preprocessed_dataset(cfg, dataset, dataset_hash, split) return dataset, prompters diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 9da658675..28ee7c7fd 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -5,6 +5,7 @@ from __future__ import annotations import functools import os from pathlib import Path +import random from typing import TYPE_CHECKING, Any, Generator from datasets import ( @@ -561,7 +562,7 @@ def merge_datasets( datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] LOG.info("Merging datasets...") - merged_dataset = concatenate_datasets(datasets) + merged_dataset = _merge_regular_datasets(datasets, cfg) if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset): LOG.debug("Shuffling merged datasets...") @@ -583,7 +584,8 @@ def merge_datasets( def _merge_streaming_datasets( datasets: list[Dataset | IterableDataset], cfg: DictDefault ) -> IterableDataset: - """Merge streaming datasets using the configured mixing strategy. + """ + Merge streaming datasets using the configured mixing strategy. Args: datasets: List of datasets to merge (at least one must be IterableDataset). @@ -593,8 +595,8 @@ def _merge_streaming_datasets( Merged IterableDataset. """ # Get mixing configuration - strategy = cfg.get("streaming_dataset_mixing_strategy", "round_robin") - weights = cfg.get("streaming_mixing_weights", None) + strategy = cfg.get("dataset_mixing_strategy", "round_robin") + weights = cfg.get("mixing_weights", None) LOG.info(f"Using streaming mixing strategy: {strategy}") @@ -602,7 +604,121 @@ def _merge_streaming_datasets( return interleave_datasets(datasets, seed=cfg.seed) if strategy == "weighted": return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed) - return interleave_datasets( datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed ) + + +def _merge_regular_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: + """ + Merge regular (non-streaming) datasets using the configured mixing strategy. + + Args: + datasets: List of regular datasets to merge. + cfg: Configuration object containing mixing settings. + + Returns: + Merged Dataset. + """ + # Get mixing configuration + strategy = cfg.get("dataset_mixing_strategy", "concatenate") + weights = cfg.get("mixing_weights", None) + + LOG.info(f"Using dataset mixing strategy: {strategy}") + + if strategy == "concatenate": + return concatenate_datasets(datasets) + if strategy == "round_robin": + return _interleave_regular_datasets_round_robin(datasets, cfg.seed) + if strategy == "weighted": + return _interleave_regular_datasets_weighted(datasets, weights, cfg.seed) + equal_weights = [1.0 / len(datasets)] * len(datasets) + return _interleave_regular_datasets_weighted(datasets, equal_weights, cfg.seed) + + +def _interleave_regular_datasets_round_robin( + datasets: list[Dataset], seed: int +) -> Dataset: + """Interleave regular datasets in round-robin fashion.""" + # Create indices for each dataset + dataset_indices = [] + for i, dataset in enumerate(datasets): + indices = [(i, j) for j in range(len(dataset))] + dataset_indices.extend(indices) + + # Interleave round-robin style + max_len = max(len(ds) for ds in datasets) + interleaved_indices = [] + + for pos in range(max_len): + for ds_idx, dataset in enumerate(datasets): + if pos < len(dataset): + interleaved_indices.append((ds_idx, pos)) + + # Create new dataset with interleaved samples + def generate_samples(): + for ds_idx, sample_idx in interleaved_indices: + yield datasets[ds_idx][sample_idx] + + # Convert to Dataset + samples = list(generate_samples()) + if not samples: + return concatenate_datasets(datasets) # Fallback + + # Create dataset from samples + first_sample = samples[0] + features_dict = { + key: [sample[key] for sample in samples] for key in first_sample.keys() + } + + return Dataset.from_dict(features_dict) + + +def _interleave_regular_datasets_weighted( + datasets: list[Dataset], weights: list[float], seed: int +) -> Dataset: + """Interleave regular datasets according to weights.""" + # Calculate total samples and samples per dataset + total_samples = sum(len(ds) for ds in datasets) + samples_per_dataset = [int(w * total_samples) for w in weights] + + # Ensure we don't exceed actual dataset sizes and adjust if needed + actual_samples = [] + for i, (ds, requested) in enumerate(zip(datasets, samples_per_dataset)): + actual = min(requested, len(ds)) + actual_samples.append(actual) + + # Create sample indices for each dataset + all_samples = [] + for ds_idx, (dataset, num_samples) in enumerate(zip(datasets, actual_samples)): + # Sample indices from this dataset + if num_samples >= len(dataset): + # Use all samples + indices = list(range(len(dataset))) + else: + # Randomly sample + indices = random.sample(range(len(dataset)), num_samples) + + for idx in indices: + all_samples.append((ds_idx, idx)) + + # Shuffle the combined samples + random.shuffle(all_samples) + + # Generate the merged dataset + def generate_samples(): + for ds_idx, sample_idx in all_samples: + yield datasets[ds_idx][sample_idx] + + # Convert to Dataset + samples = list(generate_samples()) + if not samples: + return concatenate_datasets(datasets) # Fallback + + # Create dataset from samples + first_sample = samples[0] + features_dict = { + key: [sample[key] for sample in samples] for key in first_sample.keys() + } + + return Dataset.from_dict(features_dict) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c335e6c73..901374ca9 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -944,25 +944,25 @@ class AxolotlInputConfig( "description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory." }, ) - streaming_dataset_mixing_strategy: str | None = Field( + dataset_mixing_strategy: str | None = Field( default="round_robin", json_schema_extra={ - "description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)." + "description": "Strategy for mixing multiple datasets: 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets." }, ) - streaming_mixing_weights: list[float] | None = Field( + mixing_weights: list[float] | None = Field( default=None, json_schema_extra={ - "description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'." + "description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'." }, ) - eval_streaming_dataset_mixing_strategy: str | None = Field( + eval_dataset_mixing_strategy: str | None = Field( default=None, json_schema_extra={ - "description": "Strategy for mixing multiple streaming evaluation datasets. If not set, falls back to streaming_dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'." + "description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'." }, ) - eval_streaming_mixing_weights: list[float] | None = Field( + eval_mixing_weights: list[float] | None = Field( default=None, json_schema_extra={ "description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list." diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 5ec85709a..e52455f50 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1456,45 +1456,45 @@ class StreamingValidationMixin: return self @model_validator(mode="after") - def check_streaming_mixing_weights(self): - """Validate streaming_mixing_weights configuration.""" + def check_dataset_mixing_weights(self): + """Validate dataset mixing weights configuration.""" valid_strategies = ["round_robin", "weighted", "random"] # Check main strategy and weights - strategy = getattr(self, "streaming_dataset_mixing_strategy", "round_robin") - weights = getattr(self, "streaming_mixing_weights", None) - self._validate_streaming_strategy_and_weights( + strategy = getattr(self, "dataset_mixing_strategy", "round_robin") + weights = getattr(self, "mixing_weights", None) + self._validate_dataset_strategy_and_weights( strategy, weights, - "streaming_dataset_mixing_strategy", - "streaming_mixing_weights", + "dataset_mixing_strategy", + "mixing_weights", valid_strategies, ) # Check eval-specific strategy and weights - eval_strategy = getattr(self, "eval_streaming_dataset_mixing_strategy", None) - eval_weights = getattr(self, "eval_streaming_mixing_weights", None) + eval_strategy = getattr(self, "eval_dataset_mixing_strategy", None) + eval_weights = getattr(self, "eval_mixing_weights", None) if eval_strategy is not None: - self._validate_streaming_strategy_and_weights( + self._validate_dataset_strategy_and_weights( eval_strategy, eval_weights, - "eval_streaming_dataset_mixing_strategy", - "eval_streaming_mixing_weights", + "eval_dataset_mixing_strategy", + "eval_mixing_weights", valid_strategies, ) elif eval_weights is not None: LOG.warning( - "eval_streaming_mixing_weights provided but eval_streaming_dataset_mixing_strategy is not set. " - "Weights will be ignored unless eval_streaming_dataset_mixing_strategy='weighted'." + "eval_mixing_weights provided but eval_dataset_mixing_strategy is not set. " + "Weights will be ignored unless eval_dataset_mixing_strategy='weighted'." ) return self - def _validate_streaming_strategy_and_weights( + def _validate_dataset_strategy_and_weights( self, strategy, weights, strategy_field, weights_field, valid_strategies ): - """Helper method to validate strategy and weights pair.""" + """Helper method to validate dataset mixing strategy and weights pair.""" if strategy not in valid_strategies: raise ValueError( f"{strategy_field} must be one of {valid_strategies}, " diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7d371bb41..de230caeb 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -546,3 +546,189 @@ class TestDatasetPreparation: break assert sample_count == 2 + + def test_dataset_mixing_strategy_validation(self): + """Test validation of dataset mixing strategy configuration.""" + from axolotl.utils.data.shared import _merge_regular_datasets + + # Test valid strategies work + valid_strategies = ["round_robin", "weighted", "random"] + dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]}) + dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]}) + + for strategy in valid_strategies: + cfg = DictDefault( + { + "dataset_mixing_strategy": strategy, + "mixing_weights": [0.5, 0.5] if strategy == "weighted" else None, + "seed": 42, + } + ) + # Should not raise an error + merged = _merge_regular_datasets([dataset1, dataset2], cfg) + assert len(merged) >= 1 + + def test_mixing_weights_validation(self): + """Test validation of mixing weights for weighted strategy.""" + from axolotl.utils.data.shared import _merge_regular_datasets + + dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]}) + dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]}) + + # Test valid weights work + cfg = DictDefault( + { + "dataset_mixing_strategy": "weighted", + "mixing_weights": [0.7, 0.3], + "seed": 42, + } + ) + merged = _merge_regular_datasets([dataset1, dataset2], cfg) + assert len(merged) >= 1 + + # Test invalid weights (wrong length) falls back to concatenation + cfg_invalid = DictDefault( + { + "dataset_mixing_strategy": "weighted", + "mixing_weights": [1.0], # Wrong length + "seed": 42, + } + ) + # Should fall back to concatenation with warning, not crash + merged = _merge_regular_datasets([dataset1, dataset2], cfg_invalid) + assert len(merged) == 2 # Concatenated + + def test_regular_dataset_round_robin_mixing(self): + """Test round-robin mixing for regular datasets.""" + from axolotl.utils.data.shared import _merge_regular_datasets + + # Create test datasets + dataset1 = Dataset.from_dict( + {"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]} + ) + dataset2 = Dataset.from_dict( + {"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]} + ) + + cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42}) + + merged = _merge_regular_datasets([dataset1, dataset2], cfg) + + # Should have all samples from both datasets + assert len(merged) == 4 + assert isinstance(merged, Dataset) + + # Check that samples are interleaved (not just concatenated) + sources = [sample["source"] for sample in merged] + # Round-robin should alternate between datasets + assert sources != ["ds1", "ds1", "ds2", "ds2"] # Not concatenated + + def test_regular_dataset_weighted_mixing(self): + """Test weighted mixing for regular datasets.""" + from axolotl.utils.data.shared import _merge_regular_datasets + + # Create test datasets + dataset1 = Dataset.from_dict( + { + "text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"], + "source": ["ds1"] * 4, + } + ) + dataset2 = Dataset.from_dict( + { + "text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"], + "source": ["ds2"] * 4, + } + ) + + cfg = DictDefault( + { + "dataset_mixing_strategy": "weighted", + "mixing_weights": [0.75, 0.25], # 3:1 ratio + "seed": 42, + } + ) + + merged = _merge_regular_datasets([dataset1, dataset2], cfg) + + # Should have samples proportional to weights + assert len(merged) > 0 + assert isinstance(merged, Dataset) + + # Count samples from each dataset + sources = [sample["source"] for sample in merged] + ds1_count = sources.count("ds1") + ds2_count = sources.count("ds2") + + # Should roughly follow the 3:1 ratio (allowing for rounding) + assert ds1_count >= ds2_count # ds1 should have more samples + + def test_streaming_dataset_mixing(self): + """Test that streaming datasets use HuggingFace interleave_datasets.""" + from axolotl.utils.data.shared import _merge_streaming_datasets + + # Create test streaming datasets + def gen1(): + yield {"text": "stream1_item1", "source": "stream1"} + yield {"text": "stream1_item2", "source": "stream1"} + + def gen2(): + yield {"text": "stream2_item1", "source": "stream2"} + yield {"text": "stream2_item2", "source": "stream2"} + + stream1 = IterableDataset.from_generator(gen1) + stream2 = IterableDataset.from_generator(gen2) + + cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42}) + + merged = _merge_streaming_datasets([stream1, stream2], cfg) + + # Should return an IterableDataset + assert isinstance(merged, IterableDataset) + + # Test that we can iterate and get samples + samples = list(merged.take(3)) + assert len(samples) >= 2 # Should get at least 2 samples + + # Should have samples from both datasets + sources = [sample["source"] for sample in samples] + assert len(set(sources)) >= 1 # At least one unique source + + def test_eval_streaming_config(self): + """Test eval_streaming separate from streaming config.""" + from axolotl.utils.data.sft import _is_streaming_enabled_for_split + + # Test train streaming enabled, eval streaming disabled + cfg = DictDefault({"streaming": True, "eval_streaming": False}) + + assert _is_streaming_enabled_for_split(cfg, "train") == True + assert _is_streaming_enabled_for_split(cfg, "test") == False + + # Test train streaming disabled, eval streaming enabled + cfg2 = DictDefault({"streaming": False, "eval_streaming": True}) + + assert _is_streaming_enabled_for_split(cfg2, "train") == False + assert _is_streaming_enabled_for_split(cfg2, "test") == True + + def test_eval_specific_mixing_configs(self): + """Test eval-specific mixing configs override main configs.""" + from axolotl.utils.data.sft import _get_streaming_config_for_split + + cfg = DictDefault( + { + "dataset_mixing_strategy": "round_robin", + "mixing_weights": [0.5, 0.5], + "eval_dataset_mixing_strategy": "weighted", + "eval_mixing_weights": [0.8, 0.2], + } + ) + + # Train split should use main config + train_cfg = _get_streaming_config_for_split(cfg, "train") + assert train_cfg["dataset_mixing_strategy"] == "round_robin" + assert train_cfg["mixing_weights"] == [0.5, 0.5] + + # Test split should use eval-specific config + test_cfg = _get_streaming_config_for_split(cfg, "test") + assert test_cfg["dataset_mixing_strategy"] == "weighted" + assert test_cfg["mixing_weights"] == [0.8, 0.2]