commit aa5a497a2c
parent 2176962231
Author: Dan Saunders
Date:   2025-08-20 13:46:29 +00:00

5 changed files with 340 additions and 36 deletions


@@ -75,13 +75,13 @@ def _get_streaming_config_for_split(
     # Override with eval-specific configs if they exist
     streaming_cfg = DictDefault(cfg)
-    eval_strategy = cfg.get("eval_streaming_dataset_mixing_strategy")
-    eval_weights = cfg.get("eval_streaming_mixing_weights")
+    eval_strategy = cfg.get("eval_dataset_mixing_strategy")
+    eval_weights = cfg.get("eval_mixing_weights")
     if eval_strategy is not None:
-        streaming_cfg["streaming_dataset_mixing_strategy"] = eval_strategy
+        streaming_cfg["dataset_mixing_strategy"] = eval_strategy
     if eval_weights is not None:
-        streaming_cfg["streaming_mixing_weights"] = eval_weights
+        streaming_cfg["mixing_weights"] = eval_weights

     return streaming_cfg
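
As a rough illustration of the override above (not part of the diff): the eval split starts from a copy of the training mixing config and swaps in the eval-specific values when they are set. A plain dict stands in for DictDefault here:

def get_eval_mixing_cfg(cfg: dict) -> dict:
    # Copy the training config, then apply eval-specific overrides if present
    eval_cfg = dict(cfg)
    if cfg.get("eval_dataset_mixing_strategy") is not None:
        eval_cfg["dataset_mixing_strategy"] = cfg["eval_dataset_mixing_strategy"]
    if cfg.get("eval_mixing_weights") is not None:
        eval_cfg["mixing_weights"] = cfg["eval_mixing_weights"]
    return eval_cfg

cfg = {
    "dataset_mixing_strategy": "round_robin",
    "eval_dataset_mixing_strategy": "weighted",
    "eval_mixing_weights": [0.7, 0.3],
}
print(get_eval_mixing_cfg(cfg)["dataset_mixing_strategy"])  # "weighted"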
@@ -392,10 +392,12 @@ def _load_raw_datasets(
     if cfg.sample_packing:
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)

-    dataset_hash = generate_dataset_hash_from_config(
-        cfg, datasets_configs, tokenizer.name_or_path
-    )
-    save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
+    # Only save regular datasets to disk, not streaming datasets
+    if not isinstance(dataset, IterableDataset):
+        dataset_hash = generate_dataset_hash_from_config(
+            cfg, datasets_configs, tokenizer.name_or_path
+        )
+        save_preprocessed_dataset(cfg, dataset, dataset_hash, split)

     return dataset, prompters
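
A minimal sketch of why the new guard matters (the helper name and path below are illustrative, not from the diff): a map-style Dataset can be materialized to disk, while a streaming IterableDataset is consumed lazily and is simply skipped:

from datasets import Dataset, IterableDataset

def maybe_cache(dataset, path="/tmp/preprocessed"):  # illustrative cache path
    # Mirrors the guard above: only regular (non-streaming) datasets get saved
    if not isinstance(dataset, IterableDataset):
        dataset.save_to_disk(path)
        return True
    return False

print(maybe_cache(Dataset.from_dict({"text": ["x", "y"]})))                   # True
print(maybe_cache(Dataset.from_dict({"text": ["x"]}).to_iterable_dataset()))  # False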


@@ -5,6 +5,7 @@ from __future__ import annotations
 import functools
 import os
 from pathlib import Path
+import random
 from typing import TYPE_CHECKING, Any, Generator

 from datasets import (
@@ -561,7 +562,7 @@ def merge_datasets(
     datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]

     LOG.info("Merging datasets...")
-    merged_dataset = concatenate_datasets(datasets)
+    merged_dataset = _merge_regular_datasets(datasets, cfg)

     if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
         LOG.debug("Shuffling merged datasets...")
@@ -583,7 +584,8 @@ def merge_datasets(
 def _merge_streaming_datasets(
     datasets: list[Dataset | IterableDataset], cfg: DictDefault
 ) -> IterableDataset:
-    """Merge streaming datasets using the configured mixing strategy.
+    """
+    Merge streaming datasets using the configured mixing strategy.

     Args:
         datasets: List of datasets to merge (at least one must be IterableDataset).
@@ -593,8 +595,8 @@ def _merge_streaming_datasets(
         Merged IterableDataset.
     """
     # Get mixing configuration
-    strategy = cfg.get("streaming_dataset_mixing_strategy", "round_robin")
-    weights = cfg.get("streaming_mixing_weights", None)
+    strategy = cfg.get("dataset_mixing_strategy", "round_robin")
+    weights = cfg.get("mixing_weights", None)

     LOG.info(f"Using streaming mixing strategy: {strategy}")
@@ -602,7 +604,121 @@ def _merge_streaming_datasets(
         return interleave_datasets(datasets, seed=cfg.seed)
     if strategy == "weighted":
         return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)

     return interleave_datasets(
         datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed
     )
+
+
+def _merge_regular_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
+    """
+    Merge regular (non-streaming) datasets using the configured mixing strategy.
+
+    Args:
+        datasets: List of regular datasets to merge.
+        cfg: Configuration object containing mixing settings.
+
+    Returns:
+        Merged Dataset.
+    """
+    # Get mixing configuration
+    strategy = cfg.get("dataset_mixing_strategy", "concatenate")
+    weights = cfg.get("mixing_weights", None)
+
+    LOG.info(f"Using dataset mixing strategy: {strategy}")
+
+    if strategy == "concatenate":
+        return concatenate_datasets(datasets)
+    if strategy == "round_robin":
+        return _interleave_regular_datasets_round_robin(datasets, cfg.seed)
+    if strategy == "weighted":
+        return _interleave_regular_datasets_weighted(datasets, weights, cfg.seed)
+
+    # Any other strategy (e.g. "random") falls back to equal-weight sampling
+    equal_weights = [1.0 / len(datasets)] * len(datasets)
+    return _interleave_regular_datasets_weighted(datasets, equal_weights, cfg.seed)
+
+
+def _interleave_regular_datasets_round_robin(
+    datasets: list[Dataset], seed: int
+) -> Dataset:
+    """Interleave regular datasets in round-robin fashion."""
+    # Interleave round-robin style: take the pos-th sample from each dataset in turn
+    max_len = max(len(ds) for ds in datasets)
+    interleaved_indices = []
+    for pos in range(max_len):
+        for ds_idx, dataset in enumerate(datasets):
+            if pos < len(dataset):
+                interleaved_indices.append((ds_idx, pos))
+
+    # Create new dataset with interleaved samples
+    def generate_samples():
+        for ds_idx, sample_idx in interleaved_indices:
+            yield datasets[ds_idx][sample_idx]
+
+    # Convert to Dataset
+    samples = list(generate_samples())
+    if not samples:
+        return concatenate_datasets(datasets)  # Fallback
+
+    # Create dataset from samples
+    first_sample = samples[0]
+    features_dict = {
+        key: [sample[key] for sample in samples] for key in first_sample.keys()
+    }
+    return Dataset.from_dict(features_dict)
+
+
+def _interleave_regular_datasets_weighted(
+    datasets: list[Dataset], weights: list[float], seed: int
+) -> Dataset:
+    """Interleave regular datasets according to weights."""
+    # Use a locally seeded RNG so weighted sampling is reproducible
+    rng = random.Random(seed)
+
+    # Calculate total samples and samples per dataset
+    total_samples = sum(len(ds) for ds in datasets)
+    samples_per_dataset = [int(w * total_samples) for w in weights]
+
+    # Ensure we don't exceed actual dataset sizes and adjust if needed
+    actual_samples = []
+    for ds, requested in zip(datasets, samples_per_dataset):
+        actual_samples.append(min(requested, len(ds)))
+
+    # Create sample indices for each dataset
+    all_samples = []
+    for ds_idx, (dataset, num_samples) in enumerate(zip(datasets, actual_samples)):
+        if num_samples >= len(dataset):
+            # Use all samples
+            indices = list(range(len(dataset)))
+        else:
+            # Randomly sample
+            indices = rng.sample(range(len(dataset)), num_samples)
+        for idx in indices:
+            all_samples.append((ds_idx, idx))
+
+    # Shuffle the combined samples
+    rng.shuffle(all_samples)
+
+    # Generate the merged dataset
+    def generate_samples():
+        for ds_idx, sample_idx in all_samples:
+            yield datasets[ds_idx][sample_idx]
+
+    # Convert to Dataset
+    samples = list(generate_samples())
+    if not samples:
+        return concatenate_datasets(datasets)  # Fallback
+
+    # Create dataset from samples
+    first_sample = samples[0]
+    features_dict = {
+        key: [sample[key] for sample in samples] for key in first_sample.keys()
+    }
+    return Dataset.from_dict(features_dict)
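
For the streaming branch above, all three strategies defer to Hugging Face datasets' interleave_datasets. A toy sketch (dataset contents invented) of what round_robin versus weighted interleaving yields:

from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": [f"a{i}" for i in range(3)]}).to_iterable_dataset()
ds_b = Dataset.from_dict({"text": [f"b{i}" for i in range(3)]}).to_iterable_dataset()

# round_robin: strict alternation, stopping once the first source is exhausted
round_robin = interleave_datasets([ds_a, ds_b], seed=42)
# weighted: sample from each source according to the given probabilities
weighted = interleave_datasets([ds_a, ds_b], probabilities=[0.8, 0.2], seed=42)

print([ex["text"] for ex in round_robin])  # a0, b0, a1, b1, ...
print([ex["text"] for ex in weighted])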

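A similar toy check for the new non-streaming helpers (invented data; assumes the functions added above are in scope): round_robin alternates samples and keeps draining the longer dataset, while weighted sub-samples each dataset proportionally and then shuffles:

from datasets import Dataset

ds_a = Dataset.from_dict({"text": ["a0", "a1", "a2"]})
ds_b = Dataset.from_dict({"text": ["b0"]})

merged_rr = _interleave_regular_datasets_round_robin([ds_a, ds_b], seed=0)
print(merged_rr["text"])  # ['a0', 'b0', 'a1', 'a2']

merged_w = _interleave_regular_datasets_weighted([ds_a, ds_b], [0.75, 0.25], seed=0)
print(len(merged_w))  # 3 + 1 samples requested, capped at dataset sizes, then shuffled
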

@@ -944,25 +944,25 @@ class AxolotlInputConfig(
"description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory."
},
)
streaming_dataset_mixing_strategy: str | None = Field(
dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
"description": "Strategy for mixing multiple datasets: 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
},
)
streaming_mixing_weights: list[float] | None = Field(
mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
"description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
},
)
eval_streaming_dataset_mixing_strategy: str | None = Field(
eval_dataset_mixing_strategy: str | None = Field(
default=None,
json_schema_extra={
"description": "Strategy for mixing multiple streaming evaluation datasets. If not set, falls back to streaming_dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'."
"description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'."
},
)
eval_streaming_mixing_weights: list[float] | None = Field(
eval_mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list."

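Put together, a hedged sketch of how these fields could look once a user config is loaded into cfg (the dataset entries and weights below are invented, not from this commit):

cfg = {
    "datasets": [
        {"path": "org/dataset_a", "type": "chat_template"},  # hypothetical entries
        {"path": "org/dataset_b", "type": "chat_template"},
    ],
    # How to combine the datasets above: "round_robin", "weighted", or "random"
    "dataset_mixing_strategy": "weighted",
    # One weight per dataset, summing to 1.0; only used with the "weighted" strategy
    "mixing_weights": [0.8, 0.2],
    # Optional eval-split override; falls back to the training setting when unset
    "eval_dataset_mixing_strategy": "round_robin",
}
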

@@ -1456,45 +1456,45 @@ class StreamingValidationMixin:
         return self

     @model_validator(mode="after")
-    def check_streaming_mixing_weights(self):
-        """Validate streaming_mixing_weights configuration."""
+    def check_dataset_mixing_weights(self):
+        """Validate dataset mixing weights configuration."""
         valid_strategies = ["round_robin", "weighted", "random"]

         # Check main strategy and weights
-        strategy = getattr(self, "streaming_dataset_mixing_strategy", "round_robin")
-        weights = getattr(self, "streaming_mixing_weights", None)
-        self._validate_streaming_strategy_and_weights(
+        strategy = getattr(self, "dataset_mixing_strategy", "round_robin")
+        weights = getattr(self, "mixing_weights", None)
+        self._validate_dataset_strategy_and_weights(
             strategy,
             weights,
-            "streaming_dataset_mixing_strategy",
-            "streaming_mixing_weights",
+            "dataset_mixing_strategy",
+            "mixing_weights",
             valid_strategies,
         )

         # Check eval-specific strategy and weights
-        eval_strategy = getattr(self, "eval_streaming_dataset_mixing_strategy", None)
-        eval_weights = getattr(self, "eval_streaming_mixing_weights", None)
+        eval_strategy = getattr(self, "eval_dataset_mixing_strategy", None)
+        eval_weights = getattr(self, "eval_mixing_weights", None)
         if eval_strategy is not None:
-            self._validate_streaming_strategy_and_weights(
+            self._validate_dataset_strategy_and_weights(
                 eval_strategy,
                 eval_weights,
-                "eval_streaming_dataset_mixing_strategy",
-                "eval_streaming_mixing_weights",
+                "eval_dataset_mixing_strategy",
+                "eval_mixing_weights",
                 valid_strategies,
             )
         elif eval_weights is not None:
             LOG.warning(
-                "eval_streaming_mixing_weights provided but eval_streaming_dataset_mixing_strategy is not set. "
-                "Weights will be ignored unless eval_streaming_dataset_mixing_strategy='weighted'."
+                "eval_mixing_weights provided but eval_dataset_mixing_strategy is not set. "
+                "Weights will be ignored unless eval_dataset_mixing_strategy='weighted'."
             )

         return self

-    def _validate_streaming_strategy_and_weights(
+    def _validate_dataset_strategy_and_weights(
         self, strategy, weights, strategy_field, weights_field, valid_strategies
     ):
-        """Helper method to validate strategy and weights pair."""
+        """Helper method to validate dataset mixing strategy and weights pair."""
         if strategy not in valid_strategies:
             raise ValueError(
                 f"{strategy_field} must be one of {valid_strategies}, "