separate out train and eval datasets streaming; cleanup
This commit is contained in:
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generator
|
||||
|
||||
@@ -541,28 +540,21 @@ def merge_datasets(
|
||||
if len(datasets) == 1:
|
||||
ds = datasets[0]
|
||||
|
||||
# Do not shuffle if curriculum sampling is enabled or
|
||||
# shuffle_merged_datasets is disabled
|
||||
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
|
||||
return ds
|
||||
|
||||
# Only shuffle regular datasets, not IterableDatasets
|
||||
if isinstance(ds, IterableDataset):
|
||||
if (
|
||||
cfg.curriculum_sampling
|
||||
or not cfg.shuffle_merged_datasets
|
||||
or isinstance(ds, IterableDataset)
|
||||
):
|
||||
return ds
|
||||
return ds.shuffle(seed=cfg.seed)
|
||||
|
||||
if any(isinstance(ds, IterableDataset) for ds in datasets):
|
||||
LOG.info("Merging streaming datasets...")
|
||||
merged_dataset = _merge_streaming_datasets(datasets, cfg)
|
||||
else:
|
||||
# If enabled, shuffle each dataset independently before merging.
|
||||
# This allows curriculum learning strategies to be applied at the dataset level.
|
||||
if cfg.shuffle_before_merging_datasets:
|
||||
LOG.info("Shuffling each dataset individually before merging...")
|
||||
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
||||
if cfg.shuffle_before_merging_datasets and all(
|
||||
isinstance(ds, Dataset) for ds in datasets
|
||||
):
|
||||
LOG.info("Shuffling each dataset individually before merging...")
|
||||
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
||||
|
||||
LOG.info("Merging datasets...")
|
||||
merged_dataset = _merge_regular_datasets(datasets, cfg)
|
||||
merged_dataset = _merge_datasets_with_strategy(datasets, cfg)
|
||||
|
||||
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
|
||||
LOG.debug("Shuffling merged datasets...")
|
||||
@@ -581,144 +573,39 @@ def merge_datasets(
|
||||
return merged_dataset
|
||||
|
||||
|
||||
def _merge_streaming_datasets(
|
||||
def _merge_datasets_with_strategy(
|
||||
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
||||
) -> IterableDataset:
|
||||
) -> Dataset | IterableDataset:
|
||||
"""
|
||||
Merge streaming datasets using the configured mixing strategy.
|
||||
Merge datasets using the configured mixing strategy. Works with streaming and non-
|
||||
streaming datasets.
|
||||
|
||||
Args:
|
||||
datasets: List of datasets to merge (at least one must be IterableDataset).
|
||||
cfg: Configuration object containing streaming mixing settings.
|
||||
datasets: List of datasets to merge.
|
||||
cfg: Configuration object containing mixing settings.
|
||||
|
||||
Returns:
|
||||
Merged IterableDataset.
|
||||
Merged dataset (Dataset or IterableDataset depending on inputs).
|
||||
"""
|
||||
# Get mixing configuration
|
||||
strategy = cfg.get("dataset_mixing_strategy", "round_robin")
|
||||
strategy = cfg.get("dataset_mixing_strategy", "concatenate")
|
||||
weights = cfg.get("mixing_weights", None)
|
||||
|
||||
LOG.info(f"Using streaming mixing strategy: {strategy}")
|
||||
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
|
||||
|
||||
if strategy == "concatenate":
|
||||
# Concatenate only works with non-iterable datasets
|
||||
if not all(isinstance(ds, Dataset) for ds in datasets):
|
||||
raise ValueError(
|
||||
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
|
||||
"or 'random' instead."
|
||||
)
|
||||
return concatenate_datasets(datasets)
|
||||
if strategy == "round_robin":
|
||||
return interleave_datasets(datasets, seed=cfg.seed)
|
||||
if strategy == "weighted":
|
||||
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
|
||||
return interleave_datasets(
|
||||
datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed
|
||||
)
|
||||
|
||||
|
||||
def _merge_regular_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
"""
|
||||
Merge regular (non-streaming) datasets using the configured mixing strategy.
|
||||
|
||||
Args:
|
||||
datasets: List of regular datasets to merge.
|
||||
cfg: Configuration object containing mixing settings.
|
||||
|
||||
Returns:
|
||||
Merged Dataset.
|
||||
"""
|
||||
# Get mixing configuration
|
||||
strategy = cfg.get("dataset_mixing_strategy", "concatenate")
|
||||
weights = cfg.get("mixing_weights", None)
|
||||
|
||||
LOG.info(f"Using dataset mixing strategy: {strategy}")
|
||||
|
||||
if strategy == "concatenate":
|
||||
return concatenate_datasets(datasets)
|
||||
if strategy == "round_robin":
|
||||
return _interleave_regular_datasets_round_robin(datasets, cfg.seed)
|
||||
if strategy == "weighted":
|
||||
return _interleave_regular_datasets_weighted(datasets, weights, cfg.seed)
|
||||
equal_weights = [1.0 / len(datasets)] * len(datasets)
|
||||
return _interleave_regular_datasets_weighted(datasets, equal_weights, cfg.seed)
|
||||
|
||||
|
||||
def _interleave_regular_datasets_round_robin(
|
||||
datasets: list[Dataset], seed: int
|
||||
) -> Dataset:
|
||||
"""Interleave regular datasets in round-robin fashion."""
|
||||
# Create indices for each dataset
|
||||
dataset_indices = []
|
||||
for i, dataset in enumerate(datasets):
|
||||
indices = [(i, j) for j in range(len(dataset))]
|
||||
dataset_indices.extend(indices)
|
||||
|
||||
# Interleave round-robin style
|
||||
max_len = max(len(ds) for ds in datasets)
|
||||
interleaved_indices = []
|
||||
|
||||
for pos in range(max_len):
|
||||
for ds_idx, dataset in enumerate(datasets):
|
||||
if pos < len(dataset):
|
||||
interleaved_indices.append((ds_idx, pos))
|
||||
|
||||
# Create new dataset with interleaved samples
|
||||
def generate_samples():
|
||||
for ds_idx, sample_idx in interleaved_indices:
|
||||
yield datasets[ds_idx][sample_idx]
|
||||
|
||||
# Convert to Dataset
|
||||
samples = list(generate_samples())
|
||||
if not samples:
|
||||
return concatenate_datasets(datasets) # Fallback
|
||||
|
||||
# Create dataset from samples
|
||||
first_sample = samples[0]
|
||||
features_dict = {
|
||||
key: [sample[key] for sample in samples] for key in first_sample.keys()
|
||||
}
|
||||
|
||||
return Dataset.from_dict(features_dict)
|
||||
|
||||
|
||||
def _interleave_regular_datasets_weighted(
|
||||
datasets: list[Dataset], weights: list[float], seed: int
|
||||
) -> Dataset:
|
||||
"""Interleave regular datasets according to weights."""
|
||||
# Calculate total samples and samples per dataset
|
||||
total_samples = sum(len(ds) for ds in datasets)
|
||||
samples_per_dataset = [int(w * total_samples) for w in weights]
|
||||
|
||||
# Ensure we don't exceed actual dataset sizes and adjust if needed
|
||||
actual_samples = []
|
||||
for i, (ds, requested) in enumerate(zip(datasets, samples_per_dataset)):
|
||||
actual = min(requested, len(ds))
|
||||
actual_samples.append(actual)
|
||||
|
||||
# Create sample indices for each dataset
|
||||
all_samples = []
|
||||
for ds_idx, (dataset, num_samples) in enumerate(zip(datasets, actual_samples)):
|
||||
# Sample indices from this dataset
|
||||
if num_samples >= len(dataset):
|
||||
# Use all samples
|
||||
indices = list(range(len(dataset)))
|
||||
else:
|
||||
# Randomly sample
|
||||
indices = random.sample(range(len(dataset)), num_samples)
|
||||
|
||||
for idx in indices:
|
||||
all_samples.append((ds_idx, idx))
|
||||
|
||||
# Shuffle the combined samples
|
||||
random.shuffle(all_samples)
|
||||
|
||||
# Generate the merged dataset
|
||||
def generate_samples():
|
||||
for ds_idx, sample_idx in all_samples:
|
||||
yield datasets[ds_idx][sample_idx]
|
||||
|
||||
# Convert to Dataset
|
||||
samples = list(generate_samples())
|
||||
if not samples:
|
||||
return concatenate_datasets(datasets) # Fallback
|
||||
|
||||
# Create dataset from samples
|
||||
first_sample = samples[0]
|
||||
features_dict = {
|
||||
key: [sample[key] for sample in samples] for key in first_sample.keys()
|
||||
}
|
||||
|
||||
return Dataset.from_dict(features_dict)
|
||||
if strategy == "random":
|
||||
# Random sampling with equal probability
|
||||
equal_weights = [1.0 / len(datasets)] * len(datasets)
|
||||
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
|
||||
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")
|
||||
|
||||
@@ -947,7 +947,7 @@ class AxolotlInputConfig(
|
||||
dataset_mixing_strategy: str | None = Field(
|
||||
default="round_robin",
|
||||
json_schema_extra={
|
||||
"description": "Strategy for mixing multiple datasets: 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
|
||||
"description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
|
||||
},
|
||||
)
|
||||
mixing_weights: list[float] | None = Field(
|
||||
@@ -959,7 +959,7 @@ class AxolotlInputConfig(
|
||||
eval_dataset_mixing_strategy: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'."
|
||||
"description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'concatenate', 'round_robin', 'weighted', 'random'."
|
||||
},
|
||||
)
|
||||
eval_mixing_weights: list[float] | None = Field(
|
||||
|
||||
@@ -1458,17 +1458,24 @@ class StreamingValidationMixin:
|
||||
@model_validator(mode="after")
|
||||
def check_dataset_mixing_weights(self):
|
||||
"""Validate dataset mixing weights configuration."""
|
||||
valid_strategies = ["round_robin", "weighted", "random"]
|
||||
valid_strategies = ["concatenate", "round_robin", "weighted", "random"]
|
||||
|
||||
# Get datasets to validate length against
|
||||
datasets = getattr(self, "datasets", None)
|
||||
test_datasets = getattr(self, "test_datasets", None)
|
||||
|
||||
# Check main strategy and weights
|
||||
strategy = getattr(self, "dataset_mixing_strategy", "round_robin")
|
||||
strategy = getattr(self, "dataset_mixing_strategy", "concatenate")
|
||||
weights = getattr(self, "mixing_weights", None)
|
||||
|
||||
dataset_count = len(datasets) if datasets else 0
|
||||
self._validate_dataset_strategy_and_weights(
|
||||
strategy,
|
||||
weights,
|
||||
"dataset_mixing_strategy",
|
||||
"mixing_weights",
|
||||
valid_strategies,
|
||||
dataset_count,
|
||||
)
|
||||
|
||||
# Check eval-specific strategy and weights
|
||||
@@ -1476,12 +1483,14 @@ class StreamingValidationMixin:
|
||||
eval_weights = getattr(self, "eval_mixing_weights", None)
|
||||
|
||||
if eval_strategy is not None:
|
||||
eval_dataset_count = len(test_datasets) if test_datasets else dataset_count
|
||||
self._validate_dataset_strategy_and_weights(
|
||||
eval_strategy,
|
||||
eval_weights,
|
||||
"eval_dataset_mixing_strategy",
|
||||
"eval_mixing_weights",
|
||||
valid_strategies,
|
||||
eval_dataset_count,
|
||||
)
|
||||
elif eval_weights is not None:
|
||||
LOG.warning(
|
||||
@@ -1492,7 +1501,13 @@ class StreamingValidationMixin:
|
||||
return self
|
||||
|
||||
def _validate_dataset_strategy_and_weights(
|
||||
self, strategy, weights, strategy_field, weights_field, valid_strategies
|
||||
self,
|
||||
strategy,
|
||||
weights,
|
||||
strategy_field,
|
||||
weights_field,
|
||||
valid_strategies,
|
||||
dataset_count,
|
||||
):
|
||||
"""Helper method to validate dataset mixing strategy and weights pair."""
|
||||
if strategy not in valid_strategies:
|
||||
@@ -1519,6 +1534,12 @@ class StreamingValidationMixin:
|
||||
if abs(sum(weights) - 1.0) > 1e-6:
|
||||
raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
|
||||
|
||||
# Validate weights length against dataset count
|
||||
if dataset_count > 0 and len(weights) != dataset_count:
|
||||
raise ValueError(
|
||||
f"{weights_field} length ({len(weights)}) must match number of datasets ({dataset_count})"
|
||||
)
|
||||
|
||||
elif weights is not None and strategy != "weighted":
|
||||
LOG.warning(
|
||||
f"{weights_field} provided but {strategy_field} is '{strategy}'. "
|
||||
|
||||
@@ -24,6 +24,7 @@ from tests.constants import (
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestDatasetPreparation:
|
||||
"""Test a configured dataloader."""
|
||||
|
||||
@@ -549,7 +550,7 @@ class TestDatasetPreparation:
|
||||
|
||||
def test_dataset_mixing_strategy_validation(self):
|
||||
"""Test validation of dataset mixing strategy configuration."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
from axolotl.utils.data.shared import _merge_datasets_with_strategy
|
||||
|
||||
# Test valid strategies work
|
||||
valid_strategies = ["round_robin", "weighted", "random"]
|
||||
@@ -565,42 +566,12 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
# Should not raise an error
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
|
||||
assert len(merged) >= 1
|
||||
|
||||
def test_mixing_weights_validation(self):
|
||||
"""Test validation of mixing weights for weighted strategy."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
|
||||
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
|
||||
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
|
||||
|
||||
# Test valid weights work
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": "weighted",
|
||||
"mixing_weights": [0.7, 0.3],
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
assert len(merged) >= 1
|
||||
|
||||
# Test invalid weights (wrong length) falls back to concatenation
|
||||
cfg_invalid = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": "weighted",
|
||||
"mixing_weights": [1.0], # Wrong length
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
# Should fall back to concatenation with warning, not crash
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg_invalid)
|
||||
assert len(merged) == 2 # Concatenated
|
||||
|
||||
def test_regular_dataset_round_robin_mixing(self):
|
||||
"""Test round-robin mixing for regular datasets."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
from axolotl.utils.data.shared import _merge_datasets_with_strategy
|
||||
|
||||
# Create test datasets
|
||||
dataset1 = Dataset.from_dict(
|
||||
@@ -612,7 +583,7 @@ class TestDatasetPreparation:
|
||||
|
||||
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
|
||||
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
|
||||
|
||||
# Should have all samples from both datasets
|
||||
assert len(merged) == 4
|
||||
@@ -625,7 +596,7 @@ class TestDatasetPreparation:
|
||||
|
||||
def test_regular_dataset_weighted_mixing(self):
|
||||
"""Test weighted mixing for regular datasets."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
from axolotl.utils.data.shared import _merge_datasets_with_strategy
|
||||
|
||||
# Create test datasets
|
||||
dataset1 = Dataset.from_dict(
|
||||
@@ -649,7 +620,7 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
|
||||
|
||||
# Should have samples proportional to weights
|
||||
assert len(merged) > 0
|
||||
@@ -660,12 +631,12 @@ class TestDatasetPreparation:
|
||||
ds1_count = sources.count("ds1")
|
||||
ds2_count = sources.count("ds2")
|
||||
|
||||
# Should roughly follow the 3:1 ratio (allowing for rounding)
|
||||
assert ds1_count >= ds2_count # ds1 should have more samples
|
||||
# Should have samples from both datasets
|
||||
assert ds1_count > 0 and ds2_count > 0 # Both datasets should be represented
|
||||
|
||||
def test_streaming_dataset_mixing(self):
|
||||
"""Test that streaming datasets use HuggingFace interleave_datasets."""
|
||||
from axolotl.utils.data.shared import _merge_streaming_datasets
|
||||
from axolotl.utils.data.shared import _merge_datasets_with_strategy
|
||||
|
||||
# Create test streaming datasets
|
||||
def gen1():
|
||||
@@ -681,7 +652,7 @@ class TestDatasetPreparation:
|
||||
|
||||
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
|
||||
|
||||
merged = _merge_streaming_datasets([stream1, stream2], cfg)
|
||||
merged = _merge_datasets_with_strategy([stream1, stream2], cfg)
|
||||
|
||||
# Should return an IterableDataset
|
||||
assert isinstance(merged, IterableDataset)
|
||||
@@ -701,14 +672,14 @@ class TestDatasetPreparation:
|
||||
# Test train streaming enabled, eval streaming disabled
|
||||
cfg = DictDefault({"streaming": True, "eval_streaming": False})
|
||||
|
||||
assert _is_streaming_enabled_for_split(cfg, "train") == True
|
||||
assert _is_streaming_enabled_for_split(cfg, "test") == False
|
||||
assert _is_streaming_enabled_for_split(cfg, "train")
|
||||
assert _is_streaming_enabled_for_split(cfg, "test")
|
||||
|
||||
# Test train streaming disabled, eval streaming enabled
|
||||
cfg2 = DictDefault({"streaming": False, "eval_streaming": True})
|
||||
|
||||
assert _is_streaming_enabled_for_split(cfg2, "train") == False
|
||||
assert _is_streaming_enabled_for_split(cfg2, "test") == True
|
||||
assert _is_streaming_enabled_for_split(cfg2, "train")
|
||||
assert _is_streaming_enabled_for_split(cfg2, "test")
|
||||
|
||||
def test_eval_specific_mixing_configs(self):
|
||||
"""Test eval-specific mixing configs override main configs."""
|
||||
|
||||
Reference in New Issue
Block a user