synthetic datasets for benchmarking and testing (#3518) [skip ci]

* synthetic datasets for benchmarking and testing

* fix synthetic dataset parse from config and add tests

* use type=_synthetic
Wing Lian
2026-03-21 22:47:26 -04:00
committed by GitHub
parent c9df6efdc2
commit fc3b3d1d4e
7 changed files with 347 additions and 8 deletions

View File

@@ -0,0 +1,96 @@
"""
Synthetic dataset generator for benchmarking and testing.
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
Useful for benchmarking memory usage and speed by sequence length, and for validating
weighted dataset mixes.
YAML configuration example:
datasets:
- path: synthetic
type: _synthetic
length: 1000
sequence_length: 2048
min_input_id: 100
max_input_id: 32000
seed: 42
"""
from typing import Any, Dict, Optional
import numpy as np
from datasets import Dataset
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
def __init__(
self,
sequence_length: int = 2048,
length: int = 1000,
min_input_id: int = 100,
max_input_id: int = 32000,
seed: Optional[int] = None,
):
self.sequence_length = sequence_length
self.length = length
self.min_input_id = min_input_id
self.max_input_id = max_input_id
self.seed = seed
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
LOG.info(
f"Generating synthetic dataset: {self.length} samples, "
f"sequence_length={self.sequence_length}, "
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
)
rng = np.random.default_rng(self.seed)
input_ids = rng.integers(
low=self.min_input_id,
high=self.max_input_id,
size=(self.length, self.sequence_length),
).tolist()
attention_mask = [[1] * self.sequence_length] * self.length
# labels == input_ids means we train on all tokens
labels = [row[:] for row in input_ids]
return Dataset.from_dict(
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
length = ds_cfg.get("length", 1000)
min_input_id = ds_cfg.get("min_input_id", 100)
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
seed = ds_cfg.get("seed", None)
return SyntheticDatasetStrategy(
sequence_length=sequence_length,
length=length,
min_input_id=min_input_id,
max_input_id=max_input_id,
seed=seed,
)
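As a rough usage sketch (not part of the commit), the strategy above can be exercised directly; everything here follows from the code shown, assuming SyntheticDatasetStrategy is importable from wherever the new module lives:

# Hypothetical standalone check of the generator above.
strategy = SyntheticDatasetStrategy(
    sequence_length=128, length=4, min_input_id=5, max_input_id=1000, seed=0
)
ds = strategy.wrap_dataset(dataset=None)  # the source dataset is ignored entirely
assert len(ds) == 4
assert len(ds[0]["input_ids"]) == 128
assert all(5 <= t < 1000 for t in ds[0]["input_ids"])  # high bound is exclusive
assert ds[0]["labels"] == ds[0]["input_ids"]  # loss is computed on every token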

View File

@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
    AxolotlInputConfig as AxolotlInputConfigBase,
)
-from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
+from axolotl.utils.schemas.datasets import (
+    DPODataset,
+    KTODataset,
+    SFTDataset,
+    SyntheticDataset,
+)
LOG = get_logger(__name__)
@@ -308,6 +313,14 @@ def validate_config(
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
cfg["datasets"][idx] = SyntheticDataset(
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
)
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))

View File

@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
    streaming: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
    """Load and process a single dataset based on the passed config."""
-    # Load the dataset
-    dataset = load_dataset_with_config(
-        dataset_config, cfg.hf_use_auth_token, streaming=streaming
-    )
+    # For synthetic datasets, create a minimal placeholder instead of loading from path
+    if dataset_config.type == "_synthetic":
+        dataset = Dataset.from_dict({"text": [""]})
+    else:
+        # Load the dataset
+        dataset = load_dataset_with_config(
+            dataset_config, cfg.hf_use_auth_token, streaming=streaming
+        )

    # Parse dataset type
    d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
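The placeholder only exists so the rest of the pipeline has a Dataset object to pass along; the wrapping strategy from the first file never reads it. A short sketch of that hand-off (assuming the strategy above is what gets applied downstream):

placeholder = Dataset.from_dict({"text": [""]})
strategy = SyntheticDatasetStrategy(sequence_length=64, length=2, seed=1)
synthetic = strategy.wrap_dataset(placeholder)  # placeholder contents are discarded
print(synthetic.column_names)  # ['input_ids', 'attention_mask', 'labels']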

View File

@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
    PretrainingDataset,
    SFTDataset,
    StepwiseSupervisedDataset,
+    SyntheticDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
    datasets: (
        Annotated[
-            list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
+            list[
+                SFTDataset
+                | DPODataset
+                | KTODataset
+                | StepwiseSupervisedDataset
+                | SyntheticDataset
+            ],
            MinLen(1),
        ]
        | None
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
    test_datasets: (
        Annotated[
-            list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
+            list[
+                SFTDataset
+                | DPODataset
+                | KTODataset
+                | StepwiseSupervisedDataset
+                | SyntheticDataset
+            ],
            MinLen(1),
        ]
        | None

View File

@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
    revision: str | None = None
-DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
+class SyntheticDataset(BaseModel):
+    """Synthetic dataset configuration for benchmarking and testing.
+
+    Generates datasets with configurable sequence length, dataset size, and token ID
+    ranges. Useful for benchmarking memory usage and speed by sequence length, and for
+    validating weighted dataset mixes.
+    """
+
+    path: Literal["synthetic"] = "synthetic"
+    type: Literal["_synthetic"] = "_synthetic"
+    length: int = Field(
+        default=1000,
+        json_schema_extra={"description": "Number of rows to generate"},
+    )
+    sequence_length: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Sequence length per row (defaults to sequence_len from config)"
+        },
+    )
+    min_input_id: int = Field(
+        default=100,
+        json_schema_extra={"description": "Minimum token ID for generation"},
+    )
+    max_input_id: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
+        },
+    )
+    seed: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "Random seed for reproducibility"},
+    )
+
+
+DatasetConfig = (
+    SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
+)
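For reference, a quick sanity check of the new model's defaults (a sketch assuming the pydantic v2 semantics used by the surrounding schemas):

ds = SyntheticDataset(length=2000, seed=7)
print(ds.path, ds.type)    # synthetic _synthetic (fixed literal discriminators)
print(ds.sequence_length)  # None -> resolved from cfg.sequence_len at load time
print(ds.max_input_id)     # None -> resolved from tokenizer.vocab_size at load time
print(ds.min_input_id)     # 100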