nits
This commit is contained in:
@@ -75,13 +75,13 @@ def _get_streaming_config_for_split(
|
|||||||
|
|
||||||
# Override with eval-specific configs if they exist
|
# Override with eval-specific configs if they exist
|
||||||
streaming_cfg = DictDefault(cfg)
|
streaming_cfg = DictDefault(cfg)
|
||||||
eval_strategy = cfg.get("eval_streaming_dataset_mixing_strategy")
|
eval_strategy = cfg.get("eval_dataset_mixing_strategy")
|
||||||
eval_weights = cfg.get("eval_streaming_mixing_weights")
|
eval_weights = cfg.get("eval_mixing_weights")
|
||||||
|
|
||||||
if eval_strategy is not None:
|
if eval_strategy is not None:
|
||||||
streaming_cfg["streaming_dataset_mixing_strategy"] = eval_strategy
|
streaming_cfg["dataset_mixing_strategy"] = eval_strategy
|
||||||
if eval_weights is not None:
|
if eval_weights is not None:
|
||||||
streaming_cfg["streaming_mixing_weights"] = eval_weights
|
streaming_cfg["mixing_weights"] = eval_weights
|
||||||
|
|
||||||
return streaming_cfg
|
return streaming_cfg
|
||||||
|
|
||||||
@@ -392,10 +392,12 @@ def _load_raw_datasets(
|
|||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
dataset_hash = generate_dataset_hash_from_config(
|
# Only save regular datasets to disk, not streaming datasets
|
||||||
cfg, datasets_configs, tokenizer.name_or_path
|
if not isinstance(dataset, IterableDataset):
|
||||||
)
|
dataset_hash = generate_dataset_hash_from_config(
|
||||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
cfg, datasets_configs, tokenizer.name_or_path
|
||||||
|
)
|
||||||
|
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||||
|
|
||||||
return dataset, prompters
|
return dataset, prompters
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import random
|
||||||
from typing import TYPE_CHECKING, Any, Generator
|
from typing import TYPE_CHECKING, Any, Generator
|
||||||
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -561,7 +562,7 @@ def merge_datasets(
|
|||||||
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
||||||
|
|
||||||
LOG.info("Merging datasets...")
|
LOG.info("Merging datasets...")
|
||||||
merged_dataset = concatenate_datasets(datasets)
|
merged_dataset = _merge_regular_datasets(datasets, cfg)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
|
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
|
||||||
LOG.debug("Shuffling merged datasets...")
|
LOG.debug("Shuffling merged datasets...")
|
||||||
@@ -583,7 +584,8 @@ def merge_datasets(
|
|||||||
def _merge_streaming_datasets(
|
def _merge_streaming_datasets(
|
||||||
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
||||||
) -> IterableDataset:
|
) -> IterableDataset:
|
||||||
"""Merge streaming datasets using the configured mixing strategy.
|
"""
|
||||||
|
Merge streaming datasets using the configured mixing strategy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
datasets: List of datasets to merge (at least one must be IterableDataset).
|
datasets: List of datasets to merge (at least one must be IterableDataset).
|
||||||
@@ -593,8 +595,8 @@ def _merge_streaming_datasets(
|
|||||||
Merged IterableDataset.
|
Merged IterableDataset.
|
||||||
"""
|
"""
|
||||||
# Get mixing configuration
|
# Get mixing configuration
|
||||||
strategy = cfg.get("streaming_dataset_mixing_strategy", "round_robin")
|
strategy = cfg.get("dataset_mixing_strategy", "round_robin")
|
||||||
weights = cfg.get("streaming_mixing_weights", None)
|
weights = cfg.get("mixing_weights", None)
|
||||||
|
|
||||||
LOG.info(f"Using streaming mixing strategy: {strategy}")
|
LOG.info(f"Using streaming mixing strategy: {strategy}")
|
||||||
|
|
||||||
@@ -602,7 +604,121 @@ def _merge_streaming_datasets(
|
|||||||
return interleave_datasets(datasets, seed=cfg.seed)
|
return interleave_datasets(datasets, seed=cfg.seed)
|
||||||
if strategy == "weighted":
|
if strategy == "weighted":
|
||||||
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
|
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
|
||||||
|
|
||||||
return interleave_datasets(
|
return interleave_datasets(
|
||||||
datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed
|
datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_regular_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||||
|
"""
|
||||||
|
Merge regular (non-streaming) datasets using the configured mixing strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
datasets: List of regular datasets to merge.
|
||||||
|
cfg: Configuration object containing mixing settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged Dataset.
|
||||||
|
"""
|
||||||
|
# Get mixing configuration
|
||||||
|
strategy = cfg.get("dataset_mixing_strategy", "concatenate")
|
||||||
|
weights = cfg.get("mixing_weights", None)
|
||||||
|
|
||||||
|
LOG.info(f"Using dataset mixing strategy: {strategy}")
|
||||||
|
|
||||||
|
if strategy == "concatenate":
|
||||||
|
return concatenate_datasets(datasets)
|
||||||
|
if strategy == "round_robin":
|
||||||
|
return _interleave_regular_datasets_round_robin(datasets, cfg.seed)
|
||||||
|
if strategy == "weighted":
|
||||||
|
return _interleave_regular_datasets_weighted(datasets, weights, cfg.seed)
|
||||||
|
equal_weights = [1.0 / len(datasets)] * len(datasets)
|
||||||
|
return _interleave_regular_datasets_weighted(datasets, equal_weights, cfg.seed)
|
||||||
|
|
||||||
|
|
||||||
|
def _interleave_regular_datasets_round_robin(
|
||||||
|
datasets: list[Dataset], seed: int
|
||||||
|
) -> Dataset:
|
||||||
|
"""Interleave regular datasets in round-robin fashion."""
|
||||||
|
# Create indices for each dataset
|
||||||
|
dataset_indices = []
|
||||||
|
for i, dataset in enumerate(datasets):
|
||||||
|
indices = [(i, j) for j in range(len(dataset))]
|
||||||
|
dataset_indices.extend(indices)
|
||||||
|
|
||||||
|
# Interleave round-robin style
|
||||||
|
max_len = max(len(ds) for ds in datasets)
|
||||||
|
interleaved_indices = []
|
||||||
|
|
||||||
|
for pos in range(max_len):
|
||||||
|
for ds_idx, dataset in enumerate(datasets):
|
||||||
|
if pos < len(dataset):
|
||||||
|
interleaved_indices.append((ds_idx, pos))
|
||||||
|
|
||||||
|
# Create new dataset with interleaved samples
|
||||||
|
def generate_samples():
|
||||||
|
for ds_idx, sample_idx in interleaved_indices:
|
||||||
|
yield datasets[ds_idx][sample_idx]
|
||||||
|
|
||||||
|
# Convert to Dataset
|
||||||
|
samples = list(generate_samples())
|
||||||
|
if not samples:
|
||||||
|
return concatenate_datasets(datasets) # Fallback
|
||||||
|
|
||||||
|
# Create dataset from samples
|
||||||
|
first_sample = samples[0]
|
||||||
|
features_dict = {
|
||||||
|
key: [sample[key] for sample in samples] for key in first_sample.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Dataset.from_dict(features_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _interleave_regular_datasets_weighted(
|
||||||
|
datasets: list[Dataset], weights: list[float], seed: int
|
||||||
|
) -> Dataset:
|
||||||
|
"""Interleave regular datasets according to weights."""
|
||||||
|
# Calculate total samples and samples per dataset
|
||||||
|
total_samples = sum(len(ds) for ds in datasets)
|
||||||
|
samples_per_dataset = [int(w * total_samples) for w in weights]
|
||||||
|
|
||||||
|
# Ensure we don't exceed actual dataset sizes and adjust if needed
|
||||||
|
actual_samples = []
|
||||||
|
for i, (ds, requested) in enumerate(zip(datasets, samples_per_dataset)):
|
||||||
|
actual = min(requested, len(ds))
|
||||||
|
actual_samples.append(actual)
|
||||||
|
|
||||||
|
# Create sample indices for each dataset
|
||||||
|
all_samples = []
|
||||||
|
for ds_idx, (dataset, num_samples) in enumerate(zip(datasets, actual_samples)):
|
||||||
|
# Sample indices from this dataset
|
||||||
|
if num_samples >= len(dataset):
|
||||||
|
# Use all samples
|
||||||
|
indices = list(range(len(dataset)))
|
||||||
|
else:
|
||||||
|
# Randomly sample
|
||||||
|
indices = random.sample(range(len(dataset)), num_samples)
|
||||||
|
|
||||||
|
for idx in indices:
|
||||||
|
all_samples.append((ds_idx, idx))
|
||||||
|
|
||||||
|
# Shuffle the combined samples
|
||||||
|
random.shuffle(all_samples)
|
||||||
|
|
||||||
|
# Generate the merged dataset
|
||||||
|
def generate_samples():
|
||||||
|
for ds_idx, sample_idx in all_samples:
|
||||||
|
yield datasets[ds_idx][sample_idx]
|
||||||
|
|
||||||
|
# Convert to Dataset
|
||||||
|
samples = list(generate_samples())
|
||||||
|
if not samples:
|
||||||
|
return concatenate_datasets(datasets) # Fallback
|
||||||
|
|
||||||
|
# Create dataset from samples
|
||||||
|
first_sample = samples[0]
|
||||||
|
features_dict = {
|
||||||
|
key: [sample[key] for sample in samples] for key in first_sample.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Dataset.from_dict(features_dict)
|
||||||
|
|||||||
@@ -944,25 +944,25 @@ class AxolotlInputConfig(
|
|||||||
"description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory."
|
"description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
streaming_dataset_mixing_strategy: str | None = Field(
|
dataset_mixing_strategy: str | None = Field(
|
||||||
default="round_robin",
|
default="round_robin",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
|
"description": "Strategy for mixing multiple datasets: 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
streaming_mixing_weights: list[float] | None = Field(
|
mixing_weights: list[float] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
|
"description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
eval_streaming_dataset_mixing_strategy: str | None = Field(
|
eval_dataset_mixing_strategy: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Strategy for mixing multiple streaming evaluation datasets. If not set, falls back to streaming_dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'."
|
"description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
eval_streaming_mixing_weights: list[float] | None = Field(
|
eval_mixing_weights: list[float] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list."
|
"description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list."
|
||||||
|
|||||||
@@ -1456,45 +1456,45 @@ class StreamingValidationMixin:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_streaming_mixing_weights(self):
|
def check_dataset_mixing_weights(self):
|
||||||
"""Validate streaming_mixing_weights configuration."""
|
"""Validate dataset mixing weights configuration."""
|
||||||
valid_strategies = ["round_robin", "weighted", "random"]
|
valid_strategies = ["round_robin", "weighted", "random"]
|
||||||
|
|
||||||
# Check main strategy and weights
|
# Check main strategy and weights
|
||||||
strategy = getattr(self, "streaming_dataset_mixing_strategy", "round_robin")
|
strategy = getattr(self, "dataset_mixing_strategy", "round_robin")
|
||||||
weights = getattr(self, "streaming_mixing_weights", None)
|
weights = getattr(self, "mixing_weights", None)
|
||||||
self._validate_streaming_strategy_and_weights(
|
self._validate_dataset_strategy_and_weights(
|
||||||
strategy,
|
strategy,
|
||||||
weights,
|
weights,
|
||||||
"streaming_dataset_mixing_strategy",
|
"dataset_mixing_strategy",
|
||||||
"streaming_mixing_weights",
|
"mixing_weights",
|
||||||
valid_strategies,
|
valid_strategies,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check eval-specific strategy and weights
|
# Check eval-specific strategy and weights
|
||||||
eval_strategy = getattr(self, "eval_streaming_dataset_mixing_strategy", None)
|
eval_strategy = getattr(self, "eval_dataset_mixing_strategy", None)
|
||||||
eval_weights = getattr(self, "eval_streaming_mixing_weights", None)
|
eval_weights = getattr(self, "eval_mixing_weights", None)
|
||||||
|
|
||||||
if eval_strategy is not None:
|
if eval_strategy is not None:
|
||||||
self._validate_streaming_strategy_and_weights(
|
self._validate_dataset_strategy_and_weights(
|
||||||
eval_strategy,
|
eval_strategy,
|
||||||
eval_weights,
|
eval_weights,
|
||||||
"eval_streaming_dataset_mixing_strategy",
|
"eval_dataset_mixing_strategy",
|
||||||
"eval_streaming_mixing_weights",
|
"eval_mixing_weights",
|
||||||
valid_strategies,
|
valid_strategies,
|
||||||
)
|
)
|
||||||
elif eval_weights is not None:
|
elif eval_weights is not None:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"eval_streaming_mixing_weights provided but eval_streaming_dataset_mixing_strategy is not set. "
|
"eval_mixing_weights provided but eval_dataset_mixing_strategy is not set. "
|
||||||
"Weights will be ignored unless eval_streaming_dataset_mixing_strategy='weighted'."
|
"Weights will be ignored unless eval_dataset_mixing_strategy='weighted'."
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _validate_streaming_strategy_and_weights(
|
def _validate_dataset_strategy_and_weights(
|
||||||
self, strategy, weights, strategy_field, weights_field, valid_strategies
|
self, strategy, weights, strategy_field, weights_field, valid_strategies
|
||||||
):
|
):
|
||||||
"""Helper method to validate strategy and weights pair."""
|
"""Helper method to validate dataset mixing strategy and weights pair."""
|
||||||
if strategy not in valid_strategies:
|
if strategy not in valid_strategies:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{strategy_field} must be one of {valid_strategies}, "
|
f"{strategy_field} must be one of {valid_strategies}, "
|
||||||
|
|||||||
@@ -546,3 +546,189 @@ class TestDatasetPreparation:
|
|||||||
break
|
break
|
||||||
|
|
||||||
assert sample_count == 2
|
assert sample_count == 2
|
||||||
|
|
||||||
|
def test_dataset_mixing_strategy_validation(self):
|
||||||
|
"""Test validation of dataset mixing strategy configuration."""
|
||||||
|
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||||
|
|
||||||
|
# Test valid strategies work
|
||||||
|
valid_strategies = ["round_robin", "weighted", "random"]
|
||||||
|
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
|
||||||
|
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
|
||||||
|
|
||||||
|
for strategy in valid_strategies:
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"dataset_mixing_strategy": strategy,
|
||||||
|
"mixing_weights": [0.5, 0.5] if strategy == "weighted" else None,
|
||||||
|
"seed": 42,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Should not raise an error
|
||||||
|
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||||
|
assert len(merged) >= 1
|
||||||
|
|
||||||
|
def test_mixing_weights_validation(self):
|
||||||
|
"""Test validation of mixing weights for weighted strategy."""
|
||||||
|
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||||
|
|
||||||
|
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
|
||||||
|
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
|
||||||
|
|
||||||
|
# Test valid weights work
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"dataset_mixing_strategy": "weighted",
|
||||||
|
"mixing_weights": [0.7, 0.3],
|
||||||
|
"seed": 42,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||||
|
assert len(merged) >= 1
|
||||||
|
|
||||||
|
# Test invalid weights (wrong length) falls back to concatenation
|
||||||
|
cfg_invalid = DictDefault(
|
||||||
|
{
|
||||||
|
"dataset_mixing_strategy": "weighted",
|
||||||
|
"mixing_weights": [1.0], # Wrong length
|
||||||
|
"seed": 42,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Should fall back to concatenation with warning, not crash
|
||||||
|
merged = _merge_regular_datasets([dataset1, dataset2], cfg_invalid)
|
||||||
|
assert len(merged) == 2 # Concatenated
|
||||||
|
|
||||||
|
def test_regular_dataset_round_robin_mixing(self):
|
||||||
|
"""Test round-robin mixing for regular datasets."""
|
||||||
|
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||||
|
|
||||||
|
# Create test datasets
|
||||||
|
dataset1 = Dataset.from_dict(
|
||||||
|
{"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]}
|
||||||
|
)
|
||||||
|
dataset2 = Dataset.from_dict(
|
||||||
|
{"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
|
||||||
|
|
||||||
|
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||||
|
|
||||||
|
# Should have all samples from both datasets
|
||||||
|
assert len(merged) == 4
|
||||||
|
assert isinstance(merged, Dataset)
|
||||||
|
|
||||||
|
# Check that samples are interleaved (not just concatenated)
|
||||||
|
sources = [sample["source"] for sample in merged]
|
||||||
|
# Round-robin should alternate between datasets
|
||||||
|
assert sources != ["ds1", "ds1", "ds2", "ds2"] # Not concatenated
|
||||||
|
|
||||||
|
def test_regular_dataset_weighted_mixing(self):
|
||||||
|
"""Test weighted mixing for regular datasets."""
|
||||||
|
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||||
|
|
||||||
|
# Create test datasets
|
||||||
|
dataset1 = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"],
|
||||||
|
"source": ["ds1"] * 4,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dataset2 = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"],
|
||||||
|
"source": ["ds2"] * 4,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"dataset_mixing_strategy": "weighted",
|
||||||
|
"mixing_weights": [0.75, 0.25], # 3:1 ratio
|
||||||
|
"seed": 42,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||||
|
|
||||||
|
# Should have samples proportional to weights
|
||||||
|
assert len(merged) > 0
|
||||||
|
assert isinstance(merged, Dataset)
|
||||||
|
|
||||||
|
# Count samples from each dataset
|
||||||
|
sources = [sample["source"] for sample in merged]
|
||||||
|
ds1_count = sources.count("ds1")
|
||||||
|
ds2_count = sources.count("ds2")
|
||||||
|
|
||||||
|
# Should roughly follow the 3:1 ratio (allowing for rounding)
|
||||||
|
assert ds1_count >= ds2_count # ds1 should have more samples
|
||||||
|
|
||||||
|
def test_streaming_dataset_mixing(self):
|
||||||
|
"""Test that streaming datasets use HuggingFace interleave_datasets."""
|
||||||
|
from axolotl.utils.data.shared import _merge_streaming_datasets
|
||||||
|
|
||||||
|
# Create test streaming datasets
|
||||||
|
def gen1():
|
||||||
|
yield {"text": "stream1_item1", "source": "stream1"}
|
||||||
|
yield {"text": "stream1_item2", "source": "stream1"}
|
||||||
|
|
||||||
|
def gen2():
|
||||||
|
yield {"text": "stream2_item1", "source": "stream2"}
|
||||||
|
yield {"text": "stream2_item2", "source": "stream2"}
|
||||||
|
|
||||||
|
stream1 = IterableDataset.from_generator(gen1)
|
||||||
|
stream2 = IterableDataset.from_generator(gen2)
|
||||||
|
|
||||||
|
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
|
||||||
|
|
||||||
|
merged = _merge_streaming_datasets([stream1, stream2], cfg)
|
||||||
|
|
||||||
|
# Should return an IterableDataset
|
||||||
|
assert isinstance(merged, IterableDataset)
|
||||||
|
|
||||||
|
# Test that we can iterate and get samples
|
||||||
|
samples = list(merged.take(3))
|
||||||
|
assert len(samples) >= 2 # Should get at least 2 samples
|
||||||
|
|
||||||
|
# Should have samples from both datasets
|
||||||
|
sources = [sample["source"] for sample in samples]
|
||||||
|
assert len(set(sources)) >= 1 # At least one unique source
|
||||||
|
|
||||||
|
def test_eval_streaming_config(self):
|
||||||
|
"""Test eval_streaming separate from streaming config."""
|
||||||
|
from axolotl.utils.data.sft import _is_streaming_enabled_for_split
|
||||||
|
|
||||||
|
# Test train streaming enabled, eval streaming disabled
|
||||||
|
cfg = DictDefault({"streaming": True, "eval_streaming": False})
|
||||||
|
|
||||||
|
assert _is_streaming_enabled_for_split(cfg, "train") == True
|
||||||
|
assert _is_streaming_enabled_for_split(cfg, "test") == False
|
||||||
|
|
||||||
|
# Test train streaming disabled, eval streaming enabled
|
||||||
|
cfg2 = DictDefault({"streaming": False, "eval_streaming": True})
|
||||||
|
|
||||||
|
assert _is_streaming_enabled_for_split(cfg2, "train") == False
|
||||||
|
assert _is_streaming_enabled_for_split(cfg2, "test") == True
|
||||||
|
|
||||||
|
def test_eval_specific_mixing_configs(self):
|
||||||
|
"""Test eval-specific mixing configs override main configs."""
|
||||||
|
from axolotl.utils.data.sft import _get_streaming_config_for_split
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"dataset_mixing_strategy": "round_robin",
|
||||||
|
"mixing_weights": [0.5, 0.5],
|
||||||
|
"eval_dataset_mixing_strategy": "weighted",
|
||||||
|
"eval_mixing_weights": [0.8, 0.2],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train split should use main config
|
||||||
|
train_cfg = _get_streaming_config_for_split(cfg, "train")
|
||||||
|
assert train_cfg["dataset_mixing_strategy"] == "round_robin"
|
||||||
|
assert train_cfg["mixing_weights"] == [0.5, 0.5]
|
||||||
|
|
||||||
|
# Test split should use eval-specific config
|
||||||
|
test_cfg = _get_streaming_config_for_split(cfg, "test")
|
||||||
|
assert test_cfg["dataset_mixing_strategy"] == "weighted"
|
||||||
|
assert test_cfg["mixing_weights"] == [0.8, 0.2]
|
||||||
|
|||||||
Reference in New Issue
Block a user