Files
axolotl/tests/prompt_strategies/test_synthetic.py
Wing Lian fc3b3d1d4e synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
2026-03-21 22:47:26 -04:00

126 lines
4.2 KiB
Python

"""Tests for the synthetic dataset generator."""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
from axolotl.utils.dict import DictDefault
class TestSyntheticDatasetStrategy(unittest.TestCase):
def test_generates_correct_shape(self):
strategy = SyntheticDatasetStrategy(
sequence_length=128,
length=50,
min_input_id=1,
max_input_id=1000,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
assert len(result) == 50
assert len(result[0]["input_ids"]) == 128
assert len(result[0]["attention_mask"]) == 128
assert len(result[0]["labels"]) == 128
def test_attention_mask_all_ones(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert all(v == 1 for v in row["attention_mask"])
def test_labels_equal_input_ids(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert row["input_ids"] == row["labels"]
def test_input_id_range(self):
strategy = SyntheticDatasetStrategy(
sequence_length=64,
length=100,
min_input_id=500,
max_input_id=600,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
for token_id in row["input_ids"]:
assert 500 <= token_id < 600
def test_seed_reproducibility(self):
kwargs = dict(
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
for r1, r2 in zip(result1, result2, strict=True):
assert r1["input_ids"] == r2["input_ids"]
def test_different_seeds_differ(self):
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
any_different = any(
r1["input_ids"] != r2["input_ids"]
for r1, r2 in zip(result1, result2, strict=True)
)
assert any_different
def test_load_function_with_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
ds_cfg = {
"sequence_length": 256,
"length": 5,
"min_input_id": 10,
"max_input_id": 100,
"seed": 0,
}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
assert isinstance(strategy, SyntheticDatasetStrategy)
assert strategy.sequence_length == 256
assert strategy.length == 5
assert strategy.min_input_id == 10
assert strategy.max_input_id == 100
def test_load_defaults_from_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
strategy = load(tokenizer, cfg, ds_cfg={})
assert strategy.sequence_length == 1024
assert strategy.max_input_id == 32000
assert strategy.length == 1000
def test_load_with_no_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 50000
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
strategy = load(tokenizer, cfg)
assert strategy.sequence_length == 2048
assert strategy.max_input_id == 50000
if __name__ == "__main__":
unittest.main()