nits
This commit is contained in:
@@ -546,3 +546,189 @@ class TestDatasetPreparation:
|
||||
break
|
||||
|
||||
assert sample_count == 2
|
||||
|
||||
def test_dataset_mixing_strategy_validation(self):
|
||||
"""Test validation of dataset mixing strategy configuration."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
|
||||
# Test valid strategies work
|
||||
valid_strategies = ["round_robin", "weighted", "random"]
|
||||
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
|
||||
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
|
||||
|
||||
for strategy in valid_strategies:
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": strategy,
|
||||
"mixing_weights": [0.5, 0.5] if strategy == "weighted" else None,
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
# Should not raise an error
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
assert len(merged) >= 1
|
||||
|
||||
def test_mixing_weights_validation(self):
|
||||
"""Test validation of mixing weights for weighted strategy."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
|
||||
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
|
||||
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
|
||||
|
||||
# Test valid weights work
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": "weighted",
|
||||
"mixing_weights": [0.7, 0.3],
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
assert len(merged) >= 1
|
||||
|
||||
# Test invalid weights (wrong length) falls back to concatenation
|
||||
cfg_invalid = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": "weighted",
|
||||
"mixing_weights": [1.0], # Wrong length
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
# Should fall back to concatenation with warning, not crash
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg_invalid)
|
||||
assert len(merged) == 2 # Concatenated
|
||||
|
||||
def test_regular_dataset_round_robin_mixing(self):
|
||||
"""Test round-robin mixing for regular datasets."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
|
||||
# Create test datasets
|
||||
dataset1 = Dataset.from_dict(
|
||||
{"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]}
|
||||
)
|
||||
dataset2 = Dataset.from_dict(
|
||||
{"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]}
|
||||
)
|
||||
|
||||
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
|
||||
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
|
||||
# Should have all samples from both datasets
|
||||
assert len(merged) == 4
|
||||
assert isinstance(merged, Dataset)
|
||||
|
||||
# Check that samples are interleaved (not just concatenated)
|
||||
sources = [sample["source"] for sample in merged]
|
||||
# Round-robin should alternate between datasets
|
||||
assert sources != ["ds1", "ds1", "ds2", "ds2"] # Not concatenated
|
||||
|
||||
def test_regular_dataset_weighted_mixing(self):
|
||||
"""Test weighted mixing for regular datasets."""
|
||||
from axolotl.utils.data.shared import _merge_regular_datasets
|
||||
|
||||
# Create test datasets
|
||||
dataset1 = Dataset.from_dict(
|
||||
{
|
||||
"text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"],
|
||||
"source": ["ds1"] * 4,
|
||||
}
|
||||
)
|
||||
dataset2 = Dataset.from_dict(
|
||||
{
|
||||
"text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"],
|
||||
"source": ["ds2"] * 4,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": "weighted",
|
||||
"mixing_weights": [0.75, 0.25], # 3:1 ratio
|
||||
"seed": 42,
|
||||
}
|
||||
)
|
||||
|
||||
merged = _merge_regular_datasets([dataset1, dataset2], cfg)
|
||||
|
||||
# Should have samples proportional to weights
|
||||
assert len(merged) > 0
|
||||
assert isinstance(merged, Dataset)
|
||||
|
||||
# Count samples from each dataset
|
||||
sources = [sample["source"] for sample in merged]
|
||||
ds1_count = sources.count("ds1")
|
||||
ds2_count = sources.count("ds2")
|
||||
|
||||
# Should roughly follow the 3:1 ratio (allowing for rounding)
|
||||
assert ds1_count >= ds2_count # ds1 should have more samples
|
||||
|
||||
def test_streaming_dataset_mixing(self):
|
||||
"""Test that streaming datasets use HuggingFace interleave_datasets."""
|
||||
from axolotl.utils.data.shared import _merge_streaming_datasets
|
||||
|
||||
# Create test streaming datasets
|
||||
def gen1():
|
||||
yield {"text": "stream1_item1", "source": "stream1"}
|
||||
yield {"text": "stream1_item2", "source": "stream1"}
|
||||
|
||||
def gen2():
|
||||
yield {"text": "stream2_item1", "source": "stream2"}
|
||||
yield {"text": "stream2_item2", "source": "stream2"}
|
||||
|
||||
stream1 = IterableDataset.from_generator(gen1)
|
||||
stream2 = IterableDataset.from_generator(gen2)
|
||||
|
||||
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
|
||||
|
||||
merged = _merge_streaming_datasets([stream1, stream2], cfg)
|
||||
|
||||
# Should return an IterableDataset
|
||||
assert isinstance(merged, IterableDataset)
|
||||
|
||||
# Test that we can iterate and get samples
|
||||
samples = list(merged.take(3))
|
||||
assert len(samples) >= 2 # Should get at least 2 samples
|
||||
|
||||
# Should have samples from both datasets
|
||||
sources = [sample["source"] for sample in samples]
|
||||
assert len(set(sources)) >= 1 # At least one unique source
|
||||
|
||||
def test_eval_streaming_config(self):
|
||||
"""Test eval_streaming separate from streaming config."""
|
||||
from axolotl.utils.data.sft import _is_streaming_enabled_for_split
|
||||
|
||||
# Test train streaming enabled, eval streaming disabled
|
||||
cfg = DictDefault({"streaming": True, "eval_streaming": False})
|
||||
|
||||
assert _is_streaming_enabled_for_split(cfg, "train") == True
|
||||
assert _is_streaming_enabled_for_split(cfg, "test") == False
|
||||
|
||||
# Test train streaming disabled, eval streaming enabled
|
||||
cfg2 = DictDefault({"streaming": False, "eval_streaming": True})
|
||||
|
||||
assert _is_streaming_enabled_for_split(cfg2, "train") == False
|
||||
assert _is_streaming_enabled_for_split(cfg2, "test") == True
|
||||
|
||||
def test_eval_specific_mixing_configs(self):
|
||||
"""Test eval-specific mixing configs override main configs."""
|
||||
from axolotl.utils.data.sft import _get_streaming_config_for_split
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"dataset_mixing_strategy": "round_robin",
|
||||
"mixing_weights": [0.5, 0.5],
|
||||
"eval_dataset_mixing_strategy": "weighted",
|
||||
"eval_mixing_weights": [0.8, 0.2],
|
||||
}
|
||||
)
|
||||
|
||||
# Train split should use main config
|
||||
train_cfg = _get_streaming_config_for_split(cfg, "train")
|
||||
assert train_cfg["dataset_mixing_strategy"] == "round_robin"
|
||||
assert train_cfg["mixing_weights"] == [0.5, 0.5]
|
||||
|
||||
# Test split should use eval-specific config
|
||||
test_cfg = _get_streaming_config_for_split(cfg, "test")
|
||||
assert test_cfg["dataset_mixing_strategy"] == "weighted"
|
||||
assert test_cfg["mixing_weights"] == [0.8, 0.2]
|
||||
|
||||
Reference in New Issue
Block a user