separate out train and eval datasets streaming; cleanup

This commit is contained in:
Dan Saunders
2025-08-20 15:08:31 +00:00
parent 067158e24a
commit 0843dc678a
4 changed files with 74 additions and 195 deletions

View File

@@ -24,6 +24,7 @@ from tests.constants import (
from tests.hf_offline_utils import enable_hf_offline
# pylint: disable=too-many-public-methods
class TestDatasetPreparation:
"""Test a configured dataloader."""
@@ -549,7 +550,7 @@ class TestDatasetPreparation:
def test_dataset_mixing_strategy_validation(self):
"""Test validation of dataset mixing strategy configuration."""
from axolotl.utils.data.shared import _merge_regular_datasets
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Test valid strategies work
valid_strategies = ["round_robin", "weighted", "random"]
@@ -565,42 +566,12 @@ class TestDatasetPreparation:
}
)
# Should not raise an error
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
assert len(merged) >= 1
def test_mixing_weights_validation(self):
"""Test validation of mixing weights for weighted strategy."""
from axolotl.utils.data.shared import _merge_regular_datasets
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
# Test valid weights work
cfg = DictDefault(
{
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.7, 0.3],
"seed": 42,
}
)
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
assert len(merged) >= 1
# Test invalid weights (wrong length) falls back to concatenation
cfg_invalid = DictDefault(
{
"dataset_mixing_strategy": "weighted",
"mixing_weights": [1.0], # Wrong length
"seed": 42,
}
)
# Should fall back to concatenation with warning, not crash
merged = _merge_regular_datasets([dataset1, dataset2], cfg_invalid)
assert len(merged) == 2 # Concatenated
def test_regular_dataset_round_robin_mixing(self):
"""Test round-robin mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_regular_datasets
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
@@ -612,7 +583,7 @@ class TestDatasetPreparation:
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have all samples from both datasets
assert len(merged) == 4
@@ -625,7 +596,7 @@ class TestDatasetPreparation:
def test_regular_dataset_weighted_mixing(self):
"""Test weighted mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_regular_datasets
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
@@ -649,7 +620,7 @@ class TestDatasetPreparation:
}
)
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have samples proportional to weights
assert len(merged) > 0
@@ -660,12 +631,12 @@ class TestDatasetPreparation:
ds1_count = sources.count("ds1")
ds2_count = sources.count("ds2")
# Should roughly follow the 3:1 ratio (allowing for rounding)
assert ds1_count >= ds2_count # ds1 should have more samples
# Should have samples from both datasets
assert ds1_count > 0 and ds2_count > 0 # Both datasets should be represented
def test_streaming_dataset_mixing(self):
"""Test that streaming datasets use HuggingFace interleave_datasets."""
from axolotl.utils.data.shared import _merge_streaming_datasets
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test streaming datasets
def gen1():
@@ -681,7 +652,7 @@ class TestDatasetPreparation:
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_streaming_datasets([stream1, stream2], cfg)
merged = _merge_datasets_with_strategy([stream1, stream2], cfg)
# Should return an IterableDataset
assert isinstance(merged, IterableDataset)
@@ -701,14 +672,14 @@ class TestDatasetPreparation:
# Test train streaming enabled, eval streaming disabled
cfg = DictDefault({"streaming": True, "eval_streaming": False})
assert _is_streaming_enabled_for_split(cfg, "train") == True
assert _is_streaming_enabled_for_split(cfg, "test") == False
assert _is_streaming_enabled_for_split(cfg, "train")
assert _is_streaming_enabled_for_split(cfg, "test")
# Test train streaming disabled, eval streaming enabled
cfg2 = DictDefault({"streaming": False, "eval_streaming": True})
assert _is_streaming_enabled_for_split(cfg2, "train") == False
assert _is_streaming_enabled_for_split(cfg2, "test") == True
assert _is_streaming_enabled_for_split(cfg2, "train")
assert _is_streaming_enabled_for_split(cfg2, "test")
def test_eval_specific_mixing_configs(self):
"""Test eval-specific mixing configs override main configs."""