From fc3b3d1d4ec77306f5beb42a8b05f75fc035a03d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 21 Mar 2026 22:47:26 -0400 Subject: [PATCH] synthetic datasets for benchmarking and testing (#3518) [skip ci] * synthetic datasets for benchmarking and testing * fix synthetic dataset parse from config and add tests * use type=_synthetic --- src/axolotl/prompt_strategies/_synthetic.py | 96 +++++++++++++++ src/axolotl/utils/config/__init__.py | 15 ++- src/axolotl/utils/data/sft.py | 12 +- src/axolotl/utils/schemas/config.py | 17 ++- src/axolotl/utils/schemas/datasets.py | 40 ++++++- tests/patched/test_validation.py | 50 ++++++++ tests/prompt_strategies/test_synthetic.py | 125 ++++++++++++++++++++ 7 files changed, 347 insertions(+), 8 deletions(-) create mode 100644 src/axolotl/prompt_strategies/_synthetic.py create mode 100644 tests/prompt_strategies/test_synthetic.py diff --git a/src/axolotl/prompt_strategies/_synthetic.py b/src/axolotl/prompt_strategies/_synthetic.py new file mode 100644 index 000000000..1353f09ed --- /dev/null +++ b/src/axolotl/prompt_strategies/_synthetic.py @@ -0,0 +1,96 @@ +""" +Synthetic dataset generator for benchmarking and testing. + +Generates datasets with configurable sequence length, dataset size, and token ID ranges. +Useful for benchmarking memory usage and speed by sequence length, and for validating +weighted dataset mixes. + +YAML configuration example: + + datasets: + - path: synthetic + type: _synthetic + length: 1000 + sequence_length: 2048 + min_input_id: 100 + max_input_id: 32000 + seed: 42 +""" + +from typing import Any, Dict, Optional + +import numpy as np +from datasets import Dataset + +from axolotl.prompt_tokenizers import DatasetWrappingStrategy +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class SyntheticDatasetStrategy(DatasetWrappingStrategy): + """Strategy that generates synthetic tokenized data, ignoring the source dataset.""" + + def __init__( + self, + sequence_length: int = 2048, + length: int = 1000, + min_input_id: int = 100, + max_input_id: int = 32000, + seed: Optional[int] = None, + ): + self.sequence_length = sequence_length + self.length = length + self.min_input_id = min_input_id + self.max_input_id = max_input_id + self.seed = seed + + def wrap_dataset( + self, + dataset, + process_count: int | None = None, + keep_in_memory: bool | None = False, + **kwargs, + ) -> Dataset: + LOG.info( + f"Generating synthetic dataset: {self.length} samples, " + f"sequence_length={self.sequence_length}, " + f"input_id_range=[{self.min_input_id}, {self.max_input_id})" + ) + + rng = np.random.default_rng(self.seed) + input_ids = rng.integers( + low=self.min_input_id, + high=self.max_input_id, + size=(self.length, self.sequence_length), + ).tolist() + + attention_mask = [[1] * self.sequence_length] * self.length + # labels == input_ids means we train on all tokens + labels = [row[:] for row in input_ids] + + return Dataset.from_dict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + ) + + +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + ds_cfg = ds_cfg or {} + + sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len) + length = ds_cfg.get("length", 1000) + min_input_id = ds_cfg.get("min_input_id", 100) + max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size) + seed = ds_cfg.get("seed", None) + + return SyntheticDatasetStrategy( + sequence_length=sequence_length, + length=length, + min_input_id=min_input_id, + max_input_id=max_input_id, + 
seed=seed, + ) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 61096cb86..c5bad62de 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset +from axolotl.utils.schemas.datasets import ( + DPODataset, + KTODataset, + SFTDataset, + SyntheticDataset, +) LOG = get_logger(__name__) @@ -308,6 +313,14 @@ def validate_config( cfg["datasets"][idx] = DPODataset(**ds_cfg) elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset): cfg["datasets"][idx] = KTODataset(**dict(ds_cfg)) + elif ( + ds_cfg.get("type") + if isinstance(ds_cfg, dict) + else getattr(ds_cfg, "type", None) + ) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset): + cfg["datasets"][idx] = SyntheticDataset( + **(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg)) + ) elif not isinstance(ds_cfg, SFTDataset): cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg)) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index e008b542b..0b2ec2b5f 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -376,10 +376,14 @@ def _load_and_process_single_dataset( streaming: bool = False, ) -> tuple[Dataset | IterableDataset, Prompter | None]: """Load and process a single dataset based on the passed config.""" - # Load the dataset - dataset = load_dataset_with_config( - dataset_config, cfg.hf_use_auth_token, streaming=streaming - ) + # For synthetic datasets, create a minimal placeholder instead of loading from path + if dataset_config.type == "_synthetic": + dataset = Dataset.from_dict({"text": [""]}) + else: + # Load the dataset + dataset = load_dataset_with_config( + dataset_config, cfg.hf_use_auth_token, streaming=streaming + ) # Parse dataset type d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 97a9c923e..67dea4958 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import ( PretrainingDataset, SFTDataset, StepwiseSupervisedDataset, + SyntheticDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig @@ -185,7 +186,13 @@ class AxolotlInputConfig( datasets: ( Annotated[ - list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], + list[ + SFTDataset + | DPODataset + | KTODataset + | StepwiseSupervisedDataset + | SyntheticDataset + ], MinLen(1), ] | None @@ -198,7 +205,13 @@ class AxolotlInputConfig( test_datasets: ( Annotated[ - list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], + list[ + SFTDataset + | DPODataset + | KTODataset + | StepwiseSupervisedDataset + | SyntheticDataset + ], MinLen(1), ] | None diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index e32468706..6114a63e0 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -296,4 +296,42 @@ class KTODataset(BaseModel): revision: str | None = None -DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset +class 
SyntheticDataset(BaseModel): + """Synthetic dataset configuration for benchmarking and testing. + + Generates datasets with configurable sequence length, dataset size, and token ID + ranges. Useful for benchmarking memory usage and speed by sequence length, and for + validating weighted dataset mixes. + """ + + path: Literal["synthetic"] = "synthetic" + type: Literal["_synthetic"] = "_synthetic" + length: int = Field( + default=1000, + json_schema_extra={"description": "Number of rows to generate"}, + ) + sequence_length: int | None = Field( + default=None, + json_schema_extra={ + "description": "Sequence length per row (defaults to sequence_len from config)" + }, + ) + min_input_id: int = Field( + default=100, + json_schema_extra={"description": "Minimum token ID for generation"}, + ) + max_input_id: int | None = Field( + default=None, + json_schema_extra={ + "description": "Maximum token ID for generation (defaults to tokenizer vocab_size)" + }, + ) + seed: int | None = Field( + default=None, + json_schema_extra={"description": "Random seed for reproducibility"}, + ) + + +DatasetConfig = ( + SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset +) diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index d22927940..29ab859c1 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.schemas.config import AxolotlConfigWCapabilities +from axolotl.utils.schemas.datasets import SFTDataset from axolotl.utils.wandb_ import setup_wandb_env_vars warnings.filterwarnings("error") @@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation): assert new_cfg.dataloader_num_workers == 8 assert new_cfg.dataloader_pin_memory is True assert new_cfg.dataloader_prefetch_factor == 256 + + +class TestSyntheticDatasetValidation(BaseValidation): + """ + Tests for synthetic dataset config validation + """ + + @staticmethod + def _make_cfg(minimal_cfg, datasets): + raw = dict(minimal_cfg) + raw["datasets"] = datasets + return DictDefault(raw) + + def test_synthetic_dict_config_validates(self, minimal_cfg): + """Synthetic dataset passed as a raw dict should not raise.""" + cfg = self._make_cfg( + minimal_cfg, + [ + { + "path": "synthetic", + "type": "_synthetic", + "length": 100, + "sequence_length": 64, + } + ], + ) + + new_cfg = validate_config(cfg) + assert new_cfg.datasets[0]["path"] == "synthetic" + + def test_synthetic_already_sft_does_not_crash(self, minimal_cfg): + """Synthetic dataset already parsed as SFTDataset should not raise AttributeError.""" + sft = SFTDataset(path="synthetic", type="_synthetic") + cfg = self._make_cfg(minimal_cfg, [sft]) + + # Before the fix, this raised: + # AttributeError: 'SFTDataset' object has no attribute 'get' + new_cfg = validate_config(cfg) + assert new_cfg.datasets[0]["path"] == "synthetic" + + def test_non_synthetic_sft_validates(self, minimal_cfg): + """A regular SFT dataset should validate without being treated as synthetic.""" + cfg = self._make_cfg( + minimal_cfg, + [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], + ) + + new_cfg = validate_config(cfg) + assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test" diff --git a/tests/prompt_strategies/test_synthetic.py b/tests/prompt_strategies/test_synthetic.py new file mode 100644 index 000000000..6038333c1 --- /dev/null 
+++ b/tests/prompt_strategies/test_synthetic.py @@ -0,0 +1,125 @@ +"""Tests for the synthetic dataset generator.""" + +import unittest +from unittest.mock import MagicMock + +from datasets import Dataset + +from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load +from axolotl.utils.dict import DictDefault + + +class TestSyntheticDatasetStrategy(unittest.TestCase): + def test_generates_correct_shape(self): + strategy = SyntheticDatasetStrategy( + sequence_length=128, + length=50, + min_input_id=1, + max_input_id=1000, + seed=42, + ) + dummy = Dataset.from_dict({"text": [""]}) + result = strategy.wrap_dataset(dummy) + + assert len(result) == 50 + assert len(result[0]["input_ids"]) == 128 + assert len(result[0]["attention_mask"]) == 128 + assert len(result[0]["labels"]) == 128 + + def test_attention_mask_all_ones(self): + strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0) + dummy = Dataset.from_dict({"text": [""]}) + result = strategy.wrap_dataset(dummy) + + for row in result: + assert all(v == 1 for v in row["attention_mask"]) + + def test_labels_equal_input_ids(self): + strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0) + dummy = Dataset.from_dict({"text": [""]}) + result = strategy.wrap_dataset(dummy) + + for row in result: + assert row["input_ids"] == row["labels"] + + def test_input_id_range(self): + strategy = SyntheticDatasetStrategy( + sequence_length=64, + length=100, + min_input_id=500, + max_input_id=600, + seed=42, + ) + dummy = Dataset.from_dict({"text": [""]}) + result = strategy.wrap_dataset(dummy) + + for row in result: + for token_id in row["input_ids"]: + assert 500 <= token_id < 600 + + def test_seed_reproducibility(self): + kwargs = dict( + sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123 + ) + dummy = Dataset.from_dict({"text": [""]}) + + result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy) + result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy) + + for r1, r2 in zip(result1, result2, strict=True): + assert r1["input_ids"] == r2["input_ids"] + + def test_different_seeds_differ(self): + common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000) + dummy = Dataset.from_dict({"text": [""]}) + + result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy) + result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy) + + any_different = any( + r1["input_ids"] != r2["input_ids"] + for r1, r2 in zip(result1, result2, strict=True) + ) + assert any_different + + def test_load_function_with_ds_cfg(self): + tokenizer = MagicMock() + tokenizer.vocab_size = 32000 + cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False}) + ds_cfg = { + "sequence_length": 256, + "length": 5, + "min_input_id": 10, + "max_input_id": 100, + "seed": 0, + } + + strategy = load(tokenizer, cfg, ds_cfg=ds_cfg) + assert isinstance(strategy, SyntheticDatasetStrategy) + assert strategy.sequence_length == 256 + assert strategy.length == 5 + assert strategy.min_input_id == 10 + assert strategy.max_input_id == 100 + + def test_load_defaults_from_cfg(self): + tokenizer = MagicMock() + tokenizer.vocab_size = 32000 + cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False}) + + strategy = load(tokenizer, cfg, ds_cfg={}) + assert strategy.sequence_length == 1024 + assert strategy.max_input_id == 32000 + assert strategy.length == 1000 + + def test_load_with_no_ds_cfg(self): + tokenizer = MagicMock() + tokenizer.vocab_size = 50000 + cfg = 
DictDefault({"sequence_len": 2048, "train_on_inputs": False}) + + strategy = load(tokenizer, cfg) + assert strategy.sequence_length == 2048 + assert strategy.max_input_id == 50000 + + +if __name__ == "__main__": + unittest.main()