synthetic datasets for benchmarking and testing (#3518) [skip ci]

* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
This commit is contained in:
Wing Lian
2026-03-21 22:47:26 -04:00
committed by GitHub
parent c9df6efdc2
commit fc3b3d1d4e
7 changed files with 347 additions and 8 deletions

View File

@@ -0,0 +1,96 @@
"""
Synthetic dataset generator for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
Useful for benchmarking memory usage and speed by sequence length, and for validating
weighted dataset mixes.
YAML configuration example:
datasets:
- path: synthetic
type: _synthetic
length: 1000
sequence_length: 2048
min_input_id: 100
max_input_id: 32000
seed: 42
"""
from typing import Any, Dict, Optional
import numpy as np
from datasets import Dataset
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
def __init__(
self,
sequence_length: int = 2048,
length: int = 1000,
min_input_id: int = 100,
max_input_id: int = 32000,
seed: Optional[int] = None,
):
self.sequence_length = sequence_length
self.length = length
self.min_input_id = min_input_id
self.max_input_id = max_input_id
self.seed = seed
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
LOG.info(
f"Generating synthetic dataset: {self.length} samples, "
f"sequence_length={self.sequence_length}, "
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
)
rng = np.random.default_rng(self.seed)
input_ids = rng.integers(
low=self.min_input_id,
high=self.max_input_id,
size=(self.length, self.sequence_length),
).tolist()
attention_mask = [[1] * self.sequence_length] * self.length
# labels == input_ids means we train on all tokens
labels = [row[:] for row in input_ids]
return Dataset.from_dict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
length = ds_cfg.get("length", 1000)
min_input_id = ds_cfg.get("min_input_id", 100)
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
seed = ds_cfg.get("seed", None)
return SyntheticDatasetStrategy(
sequence_length=sequence_length,
length=length,
min_input_id=min_input_id,
max_input_id=max_input_id,
seed=seed,
)

View File

@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
from axolotl.utils.schemas.datasets import (
DPODataset,
KTODataset,
SFTDataset,
SyntheticDataset,
)
LOG = get_logger(__name__)
@@ -308,6 +313,14 @@ def validate_config(
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
cfg["datasets"][idx] = SyntheticDataset(
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
)
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))

View File

@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
streaming: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# For synthetic datasets, create a minimal placeholder instead of loading from path
if dataset_config.type == "_synthetic":
dataset = Dataset.from_dict({"text": [""]})
else:
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)

View File

@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
PretrainingDataset,
SFTDataset,
StepwiseSupervisedDataset,
SyntheticDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
MinLen(1),
]
| None
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
test_datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
list[
SFTDataset
| DPODataset
| KTODataset
| StepwiseSupervisedDataset
| SyntheticDataset
],
MinLen(1),
]
| None

View File

@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
revision: str | None = None
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
class SyntheticDataset(BaseModel):
"""Synthetic dataset configuration for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
validating weighted dataset mixes.
"""
path: Literal["synthetic"] = "synthetic"
type: Literal["_synthetic"] = "_synthetic"
length: int = Field(
default=1000,
json_schema_extra={"description": "Number of rows to generate"},
)
sequence_length: int | None = Field(
default=None,
json_schema_extra={
"description": "Sequence length per row (defaults to sequence_len from config)"
},
)
min_input_id: int = Field(
default=100,
json_schema_extra={"description": "Minimum token ID for generation"},
)
max_input_id: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
},
)
seed: int | None = Field(
default=None,
json_schema_extra={"description": "Random seed for reproducibility"},
)
DatasetConfig = (
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
)

View File

@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.schemas.datasets import SFTDataset
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
assert new_cfg.dataloader_num_workers == 8
assert new_cfg.dataloader_pin_memory is True
assert new_cfg.dataloader_prefetch_factor == 256
class TestSyntheticDatasetValidation(BaseValidation):
"""
Tests for synthetic dataset config validation
"""
@staticmethod
def _make_cfg(minimal_cfg, datasets):
raw = dict(minimal_cfg)
raw["datasets"] = datasets
return DictDefault(raw)
def test_synthetic_dict_config_validates(self, minimal_cfg):
"""Synthetic dataset passed as a raw dict should not raise."""
cfg = self._make_cfg(
minimal_cfg,
[
{
"path": "synthetic",
"type": "_synthetic",
"length": 100,
"sequence_length": 64,
}
],
)
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "synthetic"
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
sft = SFTDataset(path="synthetic", type="_synthetic")
cfg = self._make_cfg(minimal_cfg, [sft])
# Before the fix, this raised:
# AttributeError: 'SFTDataset' object has no attribute 'get'
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "synthetic"
def test_non_synthetic_sft_validates(self, minimal_cfg):
"""A regular SFT dataset should validate without being treated as synthetic."""
cfg = self._make_cfg(
minimal_cfg,
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
)
new_cfg = validate_config(cfg)
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"

View File

@@ -0,0 +1,125 @@
"""Tests for the synthetic dataset generator."""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
from axolotl.utils.dict import DictDefault
class TestSyntheticDatasetStrategy(unittest.TestCase):
def test_generates_correct_shape(self):
strategy = SyntheticDatasetStrategy(
sequence_length=128,
length=50,
min_input_id=1,
max_input_id=1000,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
assert len(result) == 50
assert len(result[0]["input_ids"]) == 128
assert len(result[0]["attention_mask"]) == 128
assert len(result[0]["labels"]) == 128
def test_attention_mask_all_ones(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert all(v == 1 for v in row["attention_mask"])
def test_labels_equal_input_ids(self):
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
assert row["input_ids"] == row["labels"]
def test_input_id_range(self):
strategy = SyntheticDatasetStrategy(
sequence_length=64,
length=100,
min_input_id=500,
max_input_id=600,
seed=42,
)
dummy = Dataset.from_dict({"text": [""]})
result = strategy.wrap_dataset(dummy)
for row in result:
for token_id in row["input_ids"]:
assert 500 <= token_id < 600
def test_seed_reproducibility(self):
kwargs = dict(
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
for r1, r2 in zip(result1, result2, strict=True):
assert r1["input_ids"] == r2["input_ids"]
def test_different_seeds_differ(self):
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
dummy = Dataset.from_dict({"text": [""]})
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
any_different = any(
r1["input_ids"] != r2["input_ids"]
for r1, r2 in zip(result1, result2, strict=True)
)
assert any_different
def test_load_function_with_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
ds_cfg = {
"sequence_length": 256,
"length": 5,
"min_input_id": 10,
"max_input_id": 100,
"seed": 0,
}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
assert isinstance(strategy, SyntheticDatasetStrategy)
assert strategy.sequence_length == 256
assert strategy.length == 5
assert strategy.min_input_id == 10
assert strategy.max_input_id == 100
def test_load_defaults_from_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 32000
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
strategy = load(tokenizer, cfg, ds_cfg={})
assert strategy.sequence_length == 1024
assert strategy.max_input_id == 32000
assert strategy.length == 1000
def test_load_with_no_ds_cfg(self):
tokenizer = MagicMock()
tokenizer.vocab_size = 50000
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
strategy = load(tokenizer, cfg)
assert strategy.sequence_length == 2048
assert strategy.max_input_id == 50000
if __name__ == "__main__":
unittest.main()