synthetic datasets for benchmarking and testing (#3518) [skip ci]

* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
This commit is contained in:
Wing Lian
2026-03-21 22:47:26 -04:00
committed by GitHub
parent c9df6efdc2
commit fc3b3d1d4e
7 changed files with 347 additions and 8 deletions

View File

@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.schemas.datasets import SFTDataset
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
assert new_cfg.dataloader_num_workers == 8
assert new_cfg.dataloader_pin_memory is True
assert new_cfg.dataloader_prefetch_factor == 256
class TestSyntheticDatasetValidation(BaseValidation):
"""
Tests for synthetic dataset config validation
"""
@staticmethod
def _make_cfg(minimal_cfg, datasets):
raw = dict(minimal_cfg)
raw["datasets"] = datasets
return DictDefault(raw)
def test_synthetic_dict_config_validates(self, minimal_cfg):
"""Synthetic dataset passed as a raw dict should not raise."""
cfg = self._make_cfg(
minimal_cfg,
[
{
"path": "synthetic",
"type": "_synthetic",
"length": 100,
"sequence_length": 64,
}
],
)
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "synthetic"
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
sft = SFTDataset(path="synthetic", type="_synthetic")
cfg = self._make_cfg(minimal_cfg, [sft])
# Before the fix, this raised:
# AttributeError: 'SFTDataset' object has no attribute 'get'
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "synthetic"
def test_non_synthetic_sft_validates(self, minimal_cfg):
"""A regular SFT dataset should validate without being treated as synthetic."""
cfg = self._make_cfg(
minimal_cfg,
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
)
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"

View File

@@ -0,0 +1,125 @@
"""Tests for the synthetic dataset generator."""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
from axolotl.utils.dict import DictDefault
class TestSyntheticDatasetStrategy(unittest.TestCase):
def test_generates_correct_shape(self):
strategy = SyntheticDatasetStrategy(
sequence_length=128,
length=50,
min_input_id=1,
max_input_id=1000,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
assert len(result) == 50
assert len(result[0]["input_ids"]) == 128
assert len(result[0]["attention_mask"]) == 128
assert len(result[0]["labels"]) == 128
def test_attention_mask_all_ones(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert all(v == 1 for v in row["attention_mask"])
def test_labels_equal_input_ids(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert row["input_ids"] == row["labels"]
def test_input_id_range(self):
strategy = SyntheticDatasetStrategy(
sequence_length=64,
length=100,
min_input_id=500,
max_input_id=600,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
for token_id in row["input_ids"]:
assert 500 <= token_id < 600
def test_seed_reproducibility(self):
kwargs = dict(
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
for r1, r2 in zip(result1, result2, strict=True):
assert r1["input_ids"] == r2["input_ids"]
def test_different_seeds_differ(self):
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
any_different = any(
r1["input_ids"] != r2["input_ids"]
for r1, r2 in zip(result1, result2, strict=True)
)
assert any_different
def test_load_function_with_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
ds_cfg = {
"sequence_length": 256,
"length": 5,
"min_input_id": 10,
"max_input_id": 100,
"seed": 0,
}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
assert isinstance(strategy, SyntheticDatasetStrategy)
assert strategy.sequence_length == 256
assert strategy.length == 5
assert strategy.min_input_id == 10
assert strategy.max_input_id == 100
def test_load_defaults_from_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
strategy = load(tokenizer, cfg, ds_cfg={})
assert strategy.sequence_length == 1024
assert strategy.max_input_id == 32000
assert strategy.length == 1000
def test_load_with_no_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 50000
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
strategy = load(tokenizer, cfg)
assert strategy.sequence_length == 2048
assert strategy.max_input_id == 50000
if __name__ == "__main__":
unittest.main()