This commit is contained in:
Dan Saunders
2025-08-20 13:46:29 +00:00
parent 2176962231
commit aa5a497a2c
5 changed files with 340 additions and 36 deletions

View File

@@ -75,13 +75,13 @@ def _get_streaming_config_for_split(
# Override with eval-specific configs if they exist # Override with eval-specific configs if they exist
streaming_cfg = DictDefault(cfg) streaming_cfg = DictDefault(cfg)
eval_strategy = cfg.get("eval_streaming_dataset_mixing_strategy") eval_strategy = cfg.get("eval_dataset_mixing_strategy")
eval_weights = cfg.get("eval_streaming_mixing_weights") eval_weights = cfg.get("eval_mixing_weights")
if eval_strategy is not None: if eval_strategy is not None:
streaming_cfg["streaming_dataset_mixing_strategy"] = eval_strategy streaming_cfg["dataset_mixing_strategy"] = eval_strategy
if eval_weights is not None: if eval_weights is not None:
streaming_cfg["streaming_mixing_weights"] = eval_weights streaming_cfg["mixing_weights"] = eval_weights
return streaming_cfg return streaming_cfg
@@ -392,10 +392,12 @@ def _load_raw_datasets(
if cfg.sample_packing: if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None) dataset, _ = process_datasets_for_packing(cfg, dataset, None)
dataset_hash = generate_dataset_hash_from_config( # Only save regular datasets to disk, not streaming datasets
cfg, datasets_configs, tokenizer.name_or_path if not isinstance(dataset, IterableDataset):
) dataset_hash = generate_dataset_hash_from_config(
save_preprocessed_dataset(cfg, dataset, dataset_hash, split) cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters return dataset, prompters

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import functools import functools
import os import os
from pathlib import Path from pathlib import Path
import random
from typing import TYPE_CHECKING, Any, Generator from typing import TYPE_CHECKING, Any, Generator
from datasets import ( from datasets import (
@@ -561,7 +562,7 @@ def merge_datasets(
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
LOG.info("Merging datasets...") LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets) merged_dataset = _merge_regular_datasets(datasets, cfg)
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset): if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
LOG.debug("Shuffling merged datasets...") LOG.debug("Shuffling merged datasets...")
@@ -583,7 +584,8 @@ def merge_datasets(
def _merge_streaming_datasets( def _merge_streaming_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> IterableDataset: ) -> IterableDataset:
"""Merge streaming datasets using the configured mixing strategy. """
Merge streaming datasets using the configured mixing strategy.
Args: Args:
datasets: List of datasets to merge (at least one must be IterableDataset). datasets: List of datasets to merge (at least one must be IterableDataset).
@@ -593,8 +595,8 @@ def _merge_streaming_datasets(
Merged IterableDataset. Merged IterableDataset.
""" """
# Get mixing configuration # Get mixing configuration
strategy = cfg.get("streaming_dataset_mixing_strategy", "round_robin") strategy = cfg.get("dataset_mixing_strategy", "round_robin")
weights = cfg.get("streaming_mixing_weights", None) weights = cfg.get("mixing_weights", None)
LOG.info(f"Using streaming mixing strategy: {strategy}") LOG.info(f"Using streaming mixing strategy: {strategy}")
@@ -602,7 +604,121 @@ def _merge_streaming_datasets(
return interleave_datasets(datasets, seed=cfg.seed) return interleave_datasets(datasets, seed=cfg.seed)
if strategy == "weighted": if strategy == "weighted":
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed) return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
return interleave_datasets( return interleave_datasets(
datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed
) )
def _merge_regular_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
    """
    Merge regular (non-streaming) datasets using the configured mixing strategy.

    The strategy is read from ``cfg["dataset_mixing_strategy"]``:

    - ``"concatenate"`` (used when the key is absent): append the datasets in
      the order given.
    - ``"round_robin"``: alternate between datasets one sample at a time.
    - ``"weighted"``: sample from each dataset according to
      ``cfg["mixing_weights"]``. If the weights are missing or do not provide
      exactly one value per dataset, a warning is logged and the datasets are
      concatenated instead.
    - any other value: weighted mixing with equal weights for every dataset.

    Args:
        datasets: List of regular datasets to merge.
        cfg: Configuration object containing mixing settings.

    Returns:
        Merged Dataset.
    """
    # Get mixing configuration
    strategy = cfg.get("dataset_mixing_strategy", "concatenate")
    weights = cfg.get("mixing_weights", None)

    LOG.info(f"Using dataset mixing strategy: {strategy}")

    if strategy == "concatenate":
        return concatenate_datasets(datasets)

    if strategy == "round_robin":
        return _interleave_regular_datasets_round_robin(datasets, cfg.seed)

    if strategy == "weighted":
        # Unusable weights would either crash the helper (None) or silently
        # drop datasets (length mismatch), so degrade to plain concatenation.
        if weights is None or len(weights) != len(datasets):
            LOG.warning(
                "mixing_weights must provide exactly one weight per dataset "
                f"(got {weights!r} for {len(datasets)} datasets); "
                "falling back to concatenation."
            )
            return concatenate_datasets(datasets)
        return _interleave_regular_datasets_weighted(datasets, weights, cfg.seed)

    # Remaining strategies (e.g. "random") mix with equal probability.
    equal_weights = [1.0 / len(datasets)] * len(datasets)
    return _interleave_regular_datasets_weighted(datasets, equal_weights, cfg.seed)
def _interleave_regular_datasets_round_robin(
    datasets: list[Dataset], seed: int
) -> Dataset:
    """
    Interleave regular datasets in round-robin fashion.

    Position 0 of every dataset is emitted first (in dataset order), then
    position 1, and so on. Datasets that run out of samples are skipped, so
    every sample of every dataset appears exactly once.

    Rows are picked by index from the concatenated dataset instead of being
    copied into Python lists, so the original feature schema is kept.

    Args:
        datasets: List of regular datasets to interleave.
        seed: Unused; the round-robin order is fully deterministic. Kept so
            this helper has the same signature as the weighted one.

    Returns:
        Dataset containing all samples in round-robin order.
    """
    sizes = [len(ds) for ds in datasets]

    # Row offset of each dataset inside the concatenated dataset
    offsets = []
    running_total = 0
    for size in sizes:
        offsets.append(running_total)
        running_total += size

    # Interleave round-robin style, skipping exhausted datasets
    interleaved_indices = []
    for pos in range(max(sizes)):
        for offset, size in zip(offsets, sizes):
            if pos < size:
                interleaved_indices.append(offset + pos)

    merged = concatenate_datasets(datasets)
    if not interleaved_indices:
        return merged  # All datasets are empty

    return merged.select(interleaved_indices)
def _interleave_regular_datasets_weighted(
    datasets: list[Dataset], weights: list[float], seed: int
) -> Dataset:
    """
    Interleave regular datasets according to weights.

    Each dataset contributes ``int(weight * total_samples)`` samples, capped
    at its own length (no oversampling), where ``total_samples`` is the
    combined length of all datasets. The chosen samples are then shuffled
    together. Sampling and shuffling use a private generator seeded with
    ``seed``, so the result is reproducible and the global ``random`` state
    is left untouched.

    Args:
        datasets: List of regular datasets to mix.
        weights: One weight per dataset.
        seed: Seed for subsampling and for the final shuffle.

    Returns:
        Dataset with the selected samples in shuffled order.

    Raises:
        ValueError: If ``weights`` is None or its length differs from the
            number of datasets.
    """
    if weights is None or len(weights) != len(datasets):
        raise ValueError(
            "weights must provide exactly one value per dataset "
            f"(got {weights!r} for {len(datasets)} datasets)"
        )

    rng = random.Random(seed)
    sizes = [len(ds) for ds in datasets]
    total_samples = sum(sizes)

    # Indices refer to rows of the concatenated dataset
    selected_indices = []
    offset = 0
    for size, weight in zip(sizes, weights):
        # Never request more samples than the dataset actually holds
        num_samples = min(int(weight * total_samples), size)
        if num_samples >= size:
            # Use all samples
            local_indices = range(size)
        else:
            # Randomly sample without replacement
            local_indices = rng.sample(range(size), num_samples)
        selected_indices.extend(offset + idx for idx in local_indices)
        offset += size

    # Shuffle the combined samples
    rng.shuffle(selected_indices)

    merged = concatenate_datasets(datasets)
    if not selected_indices:
        return merged  # Fallback when nothing was selected

    return merged.select(selected_indices)

View File

@@ -944,25 +944,25 @@ class AxolotlInputConfig(
"description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory." "description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory."
}, },
) )
streaming_dataset_mixing_strategy: str | None = Field( dataset_mixing_strategy: str | None = Field(
default="round_robin", default="round_robin",
json_schema_extra={ json_schema_extra={
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)." "description": "Strategy for mixing multiple datasets: 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
}, },
) )
streaming_mixing_weights: list[float] | None = Field( mixing_weights: list[float] | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'." "description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
}, },
) )
eval_streaming_dataset_mixing_strategy: str | None = Field( eval_dataset_mixing_strategy: str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Strategy for mixing multiple streaming evaluation datasets. If not set, falls back to streaming_dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'." "description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'."
}, },
) )
eval_streaming_mixing_weights: list[float] | None = Field( eval_mixing_weights: list[float] | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list." "description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list."

View File

@@ -1456,45 +1456,45 @@ class StreamingValidationMixin:
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def check_streaming_mixing_weights(self): def check_dataset_mixing_weights(self):
"""Validate streaming_mixing_weights configuration.""" """Validate dataset mixing weights configuration."""
valid_strategies = ["round_robin", "weighted", "random"] valid_strategies = ["round_robin", "weighted", "random"]
# Check main strategy and weights # Check main strategy and weights
strategy = getattr(self, "streaming_dataset_mixing_strategy", "round_robin") strategy = getattr(self, "dataset_mixing_strategy", "round_robin")
weights = getattr(self, "streaming_mixing_weights", None) weights = getattr(self, "mixing_weights", None)
self._validate_streaming_strategy_and_weights( self._validate_dataset_strategy_and_weights(
strategy, strategy,
weights, weights,
"streaming_dataset_mixing_strategy", "dataset_mixing_strategy",
"streaming_mixing_weights", "mixing_weights",
valid_strategies, valid_strategies,
) )
# Check eval-specific strategy and weights # Check eval-specific strategy and weights
eval_strategy = getattr(self, "eval_streaming_dataset_mixing_strategy", None) eval_strategy = getattr(self, "eval_dataset_mixing_strategy", None)
eval_weights = getattr(self, "eval_streaming_mixing_weights", None) eval_weights = getattr(self, "eval_mixing_weights", None)
if eval_strategy is not None: if eval_strategy is not None:
self._validate_streaming_strategy_and_weights( self._validate_dataset_strategy_and_weights(
eval_strategy, eval_strategy,
eval_weights, eval_weights,
"eval_streaming_dataset_mixing_strategy", "eval_dataset_mixing_strategy",
"eval_streaming_mixing_weights", "eval_mixing_weights",
valid_strategies, valid_strategies,
) )
elif eval_weights is not None: elif eval_weights is not None:
LOG.warning( LOG.warning(
"eval_streaming_mixing_weights provided but eval_streaming_dataset_mixing_strategy is not set. " "eval_mixing_weights provided but eval_dataset_mixing_strategy is not set. "
"Weights will be ignored unless eval_streaming_dataset_mixing_strategy='weighted'." "Weights will be ignored unless eval_dataset_mixing_strategy='weighted'."
) )
return self return self
def _validate_streaming_strategy_and_weights( def _validate_dataset_strategy_and_weights(
self, strategy, weights, strategy_field, weights_field, valid_strategies self, strategy, weights, strategy_field, weights_field, valid_strategies
): ):
"""Helper method to validate strategy and weights pair.""" """Helper method to validate dataset mixing strategy and weights pair."""
if strategy not in valid_strategies: if strategy not in valid_strategies:
raise ValueError( raise ValueError(
f"{strategy_field} must be one of {valid_strategies}, " f"{strategy_field} must be one of {valid_strategies}, "

View File

@@ -546,3 +546,189 @@ class TestDatasetPreparation:
break break
assert sample_count == 2 assert sample_count == 2
def test_dataset_mixing_strategy_validation(self):
    """Every supported mixing strategy merges two tiny datasets without error."""
    from axolotl.utils.data.shared import _merge_regular_datasets

    first = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
    second = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})

    for strategy in ("round_robin", "weighted", "random"):
        # Only the weighted strategy is given explicit weights
        if strategy == "weighted":
            weights = [0.5, 0.5]
        else:
            weights = None

        settings = {
            "dataset_mixing_strategy": strategy,
            "mixing_weights": weights,
            "seed": 42,
        }

        # Merging must succeed and yield at least one sample
        result = _merge_regular_datasets([first, second], DictDefault(settings))
        assert len(result) >= 1
def test_mixing_weights_validation(self):
    """Weighted merging accepts valid weights and concatenates on invalid ones."""
    from axolotl.utils.data.shared import _merge_regular_datasets

    pair = [
        Dataset.from_dict({"text": ["a"], "source": ["ds1"]}),
        Dataset.from_dict({"text": ["b"], "source": ["ds2"]}),
    ]

    # One weight per dataset: merge must produce something
    good_cfg = DictDefault(
        {
            "dataset_mixing_strategy": "weighted",
            "mixing_weights": [0.7, 0.3],
            "seed": 42,
        }
    )
    assert len(_merge_regular_datasets(pair, good_cfg)) >= 1

    # A single weight for two datasets is invalid; the merge is expected to
    # fall back to concatenation (with a warning) rather than crash.
    bad_cfg = DictDefault(
        {
            "dataset_mixing_strategy": "weighted",
            "mixing_weights": [1.0],  # Wrong length
            "seed": 42,
        }
    )
    fallback = _merge_regular_datasets(pair, bad_cfg)
    assert len(fallback) == 2  # Concatenated
def test_regular_dataset_round_robin_mixing(self):
    """Test round-robin mixing for regular datasets."""
    from axolotl.utils.data.shared import _merge_regular_datasets

    # Create test datasets
    dataset1 = Dataset.from_dict(
        {"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]}
    )
    dataset2 = Dataset.from_dict(
        {"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]}
    )

    cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})

    merged = _merge_regular_datasets([dataset1, dataset2], cfg)

    # Should have all samples from both datasets
    assert len(merged) == 4
    assert isinstance(merged, Dataset)

    # Round-robin is deterministic: it must alternate strictly between the
    # datasets and keep each dataset's own order. Checking only "not the
    # concatenated order" would accept almost any permutation.
    assert merged["source"] == ["ds1", "ds2", "ds1", "ds2"]
    assert merged["text"] == ["ds1_item1", "ds2_item1", "ds1_item2", "ds2_item2"]
def test_regular_dataset_weighted_mixing(self):
    """Test weighted mixing for regular datasets."""
    from axolotl.utils.data.shared import _merge_regular_datasets

    # Create test datasets
    dataset1 = Dataset.from_dict(
        {
            "text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"],
            "source": ["ds1"] * 4,
        }
    )
    dataset2 = Dataset.from_dict(
        {
            "text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"],
            "source": ["ds2"] * 4,
        }
    )

    cfg = DictDefault(
        {
            "dataset_mixing_strategy": "weighted",
            "mixing_weights": [0.75, 0.25],  # 3:1 ratio
            "seed": 42,
        }
    )

    merged = _merge_regular_datasets([dataset1, dataset2], cfg)
    assert isinstance(merged, Dataset)

    # Count samples from each dataset
    sources = list(merged["source"])
    ds1_count = sources.count("ds1")
    ds2_count = sources.count("ds2")

    # With 8 samples in total the requested shares are int(0.75 * 8) = 6 and
    # int(0.25 * 8) = 2. ds1 only holds 4 samples and is never oversampled,
    # so the counts are exact regardless of the random order.
    assert ds1_count == 4
    assert ds2_count == 2
    assert len(merged) == 6
def test_streaming_dataset_mixing(self):
    """Test that streaming datasets use HuggingFace interleave_datasets."""
    from axolotl.utils.data.shared import _merge_streaming_datasets

    # Create test streaming datasets
    def gen1():
        yield {"text": "stream1_item1", "source": "stream1"}
        yield {"text": "stream1_item2", "source": "stream1"}

    def gen2():
        yield {"text": "stream2_item1", "source": "stream2"}
        yield {"text": "stream2_item2", "source": "stream2"}

    stream1 = IterableDataset.from_generator(gen1)
    stream2 = IterableDataset.from_generator(gen2)

    cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})

    merged = _merge_streaming_datasets([stream1, stream2], cfg)

    # Should return an IterableDataset
    assert isinstance(merged, IterableDataset)

    # Test that we can iterate and get samples
    samples = list(merged.take(3))
    assert len(samples) >= 2  # Should get at least 2 samples

    # Round-robin alternates between the streams, so the first samples must
    # cover both sources. Checking for "at least one unique source" is true
    # for any non-empty result and would not detect a missing stream.
    sources = {sample["source"] for sample in samples}
    assert sources == {"stream1", "stream2"}
def test_eval_streaming_config(self):
    """The train split follows `streaming`; the test split follows `eval_streaming`."""
    from axolotl.utils.data.sft import _is_streaming_enabled_for_split

    # (streaming, eval_streaming): first train-only streaming, then
    # eval-only streaming.
    for train_flag, eval_flag in ((True, False), (False, True)):
        cfg = DictDefault({"streaming": train_flag, "eval_streaming": eval_flag})
        assert _is_streaming_enabled_for_split(cfg, "train") == train_flag
        assert _is_streaming_enabled_for_split(cfg, "test") == eval_flag
def test_eval_specific_mixing_configs(self):
    """Eval-specific mixing settings replace the main ones for the test split only."""
    from axolotl.utils.data.sft import _get_streaming_config_for_split

    base_cfg = DictDefault(
        {
            "dataset_mixing_strategy": "round_robin",
            "mixing_weights": [0.5, 0.5],
            "eval_dataset_mixing_strategy": "weighted",
            "eval_mixing_weights": [0.8, 0.2],
        }
    )

    # Expected (strategy, weights) per split: train keeps the main config,
    # test picks up the eval-specific overrides.
    expected_by_split = {
        "train": ("round_robin", [0.5, 0.5]),
        "test": ("weighted", [0.8, 0.2]),
    }

    for split, (strategy, weights) in expected_by_split.items():
        split_cfg = _get_streaming_config_for_split(base_cfg, split)
        assert split_cfg["dataset_mixing_strategy"] == strategy
        assert split_cfg["mixing_weights"] == weights