synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing * fix synthetic dataset parse from config and add tests * use type=_synthetic
This commit is contained in:
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
Synthetic dataset generator for benchmarking and testing.
|
||||||
|
|
||||||
|
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
|
||||||
|
Useful for benchmarking memory usage and speed by sequence length, and for validating
|
||||||
|
weighted dataset mixes.
|
||||||
|
|
||||||
|
YAML configuration example:
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: synthetic
|
||||||
|
type: _synthetic
|
||||||
|
length: 1000
|
||||||
|
sequence_length: 2048
|
||||||
|
min_input_id: 100
|
||||||
|
max_input_id: 32000
|
||||||
|
seed: 42
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
|
||||||
|
"""Strategy that generates synthetic tokenized data, ignoring the source dataset."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sequence_length: int = 2048,
|
||||||
|
length: int = 1000,
|
||||||
|
min_input_id: int = 100,
|
||||||
|
max_input_id: int = 32000,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.sequence_length = sequence_length
|
||||||
|
self.length = length
|
||||||
|
self.min_input_id = min_input_id
|
||||||
|
self.max_input_id = max_input_id
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def wrap_dataset(
|
||||||
|
self,
|
||||||
|
dataset,
|
||||||
|
process_count: int | None = None,
|
||||||
|
keep_in_memory: bool | None = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dataset:
|
||||||
|
LOG.info(
|
||||||
|
f"Generating synthetic dataset: {self.length} samples, "
|
||||||
|
f"sequence_length={self.sequence_length}, "
|
||||||
|
f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
rng = np.random.default_rng(self.seed)
|
||||||
|
input_ids = rng.integers(
|
||||||
|
low=self.min_input_id,
|
||||||
|
high=self.max_input_id,
|
||||||
|
size=(self.length, self.sequence_length),
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
attention_mask = [[1] * self.sequence_length] * self.length
|
||||||
|
# labels == input_ids means we train on all tokens
|
||||||
|
labels = [row[:] for row in input_ids]
|
||||||
|
|
||||||
|
return Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
ds_cfg = ds_cfg or {}
|
||||||
|
|
||||||
|
sequence_length = ds_cfg.get("sequence_length", cfg.sequence_len)
|
||||||
|
length = ds_cfg.get("length", 1000)
|
||||||
|
min_input_id = ds_cfg.get("min_input_id", 100)
|
||||||
|
max_input_id = ds_cfg.get("max_input_id", tokenizer.vocab_size)
|
||||||
|
seed = ds_cfg.get("seed", None)
|
||||||
|
|
||||||
|
return SyntheticDatasetStrategy(
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
length=length,
|
||||||
|
min_input_id=min_input_id,
|
||||||
|
max_input_id=max_input_id,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
|
|||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
from axolotl.utils.schemas.datasets import (
|
||||||
|
DPODataset,
|
||||||
|
KTODataset,
|
||||||
|
SFTDataset,
|
||||||
|
SyntheticDataset,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -308,6 +313,14 @@ def validate_config(
|
|||||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||||
|
elif (
|
||||||
|
ds_cfg.get("type")
|
||||||
|
if isinstance(ds_cfg, dict)
|
||||||
|
else getattr(ds_cfg, "type", None)
|
||||||
|
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
|
||||||
|
cfg["datasets"][idx] = SyntheticDataset(
|
||||||
|
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
|
||||||
|
)
|
||||||
elif not isinstance(ds_cfg, SFTDataset):
|
elif not isinstance(ds_cfg, SFTDataset):
|
||||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
|
|||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||||
"""Load and process a single dataset based on the passed config."""
|
"""Load and process a single dataset based on the passed config."""
|
||||||
# Load the dataset
|
# For synthetic datasets, create a minimal placeholder instead of loading from path
|
||||||
dataset = load_dataset_with_config(
|
if dataset_config.type == "_synthetic":
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
dataset = Dataset.from_dict({"text": [""]})
|
||||||
)
|
else:
|
||||||
|
# Load the dataset
|
||||||
|
dataset = load_dataset_with_config(
|
||||||
|
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
# Parse dataset type
|
# Parse dataset type
|
||||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
|
|||||||
PretrainingDataset,
|
PretrainingDataset,
|
||||||
SFTDataset,
|
SFTDataset,
|
||||||
StepwiseSupervisedDataset,
|
StepwiseSupervisedDataset,
|
||||||
|
SyntheticDataset,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||||
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
list[
|
||||||
|
SFTDataset
|
||||||
|
| DPODataset
|
||||||
|
| KTODataset
|
||||||
|
| StepwiseSupervisedDataset
|
||||||
|
| SyntheticDataset
|
||||||
|
],
|
||||||
MinLen(1),
|
MinLen(1),
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
test_datasets: (
|
test_datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
list[
|
||||||
|
SFTDataset
|
||||||
|
| DPODataset
|
||||||
|
| KTODataset
|
||||||
|
| StepwiseSupervisedDataset
|
||||||
|
| SyntheticDataset
|
||||||
|
],
|
||||||
MinLen(1),
|
MinLen(1),
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
|
|||||||
@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
|
|||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
class SyntheticDataset(BaseModel):
|
||||||
|
"""Synthetic dataset configuration for benchmarking and testing.
|
||||||
|
|
||||||
|
Generates datasets with configurable sequence length, dataset size, and token ID
|
||||||
|
ranges. Useful for benchmarking memory usage and speed by sequence length, and for
|
||||||
|
validating weighted dataset mixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path: Literal["synthetic"] = "synthetic"
|
||||||
|
type: Literal["_synthetic"] = "_synthetic"
|
||||||
|
length: int = Field(
|
||||||
|
default=1000,
|
||||||
|
json_schema_extra={"description": "Number of rows to generate"},
|
||||||
|
)
|
||||||
|
sequence_length: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Sequence length per row (defaults to sequence_len from config)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
min_input_id: int = Field(
|
||||||
|
default=100,
|
||||||
|
json_schema_extra={"description": "Minimum token ID for generation"},
|
||||||
|
)
|
||||||
|
max_input_id: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
seed: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Random seed for reproducibility"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DatasetConfig = (
|
||||||
|
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
|
from axolotl.utils.schemas.datasets import SFTDataset
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
|
|||||||
assert new_cfg.dataloader_num_workers == 8
|
assert new_cfg.dataloader_num_workers == 8
|
||||||
assert new_cfg.dataloader_pin_memory is True
|
assert new_cfg.dataloader_pin_memory is True
|
||||||
assert new_cfg.dataloader_prefetch_factor == 256
|
assert new_cfg.dataloader_prefetch_factor == 256
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyntheticDatasetValidation(BaseValidation):
|
||||||
|
"""
|
||||||
|
Tests for synthetic dataset config validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_cfg(minimal_cfg, datasets):
|
||||||
|
raw = dict(minimal_cfg)
|
||||||
|
raw["datasets"] = datasets
|
||||||
|
return DictDefault(raw)
|
||||||
|
|
||||||
|
def test_synthetic_dict_config_validates(self, minimal_cfg):
|
||||||
|
"""Synthetic dataset passed as a raw dict should not raise."""
|
||||||
|
cfg = self._make_cfg(
|
||||||
|
minimal_cfg,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"path": "synthetic",
|
||||||
|
"type": "_synthetic",
|
||||||
|
"length": 100,
|
||||||
|
"sequence_length": 64,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||||
|
|
||||||
|
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
|
||||||
|
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
|
||||||
|
sft = SFTDataset(path="synthetic", type="_synthetic")
|
||||||
|
cfg = self._make_cfg(minimal_cfg, [sft])
|
||||||
|
|
||||||
|
# Before the fix, this raised:
|
||||||
|
# AttributeError: 'SFTDataset' object has no attribute 'get'
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||||
|
|
||||||
|
def test_non_synthetic_sft_validates(self, minimal_cfg):
|
||||||
|
"""A regular SFT dataset should validate without being treated as synthetic."""
|
||||||
|
cfg = self._make_cfg(
|
||||||
|
minimal_cfg,
|
||||||
|
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"
|
||||||
|
|||||||
125
tests/prompt_strategies/test_synthetic.py
Normal file
125
tests/prompt_strategies/test_synthetic.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Tests for the synthetic dataset generator."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyntheticDatasetStrategy(unittest.TestCase):
|
||||||
|
def test_generates_correct_shape(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(
|
||||||
|
sequence_length=128,
|
||||||
|
length=50,
|
||||||
|
min_input_id=1,
|
||||||
|
max_input_id=1000,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
assert len(result) == 50
|
||||||
|
assert len(result[0]["input_ids"]) == 128
|
||||||
|
assert len(result[0]["attention_mask"]) == 128
|
||||||
|
assert len(result[0]["labels"]) == 128
|
||||||
|
|
||||||
|
def test_attention_mask_all_ones(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for row in result:
|
||||||
|
assert all(v == 1 for v in row["attention_mask"])
|
||||||
|
|
||||||
|
def test_labels_equal_input_ids(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for row in result:
|
||||||
|
assert row["input_ids"] == row["labels"]
|
||||||
|
|
||||||
|
def test_input_id_range(self):
|
||||||
|
strategy = SyntheticDatasetStrategy(
|
||||||
|
sequence_length=64,
|
||||||
|
length=100,
|
||||||
|
min_input_id=500,
|
||||||
|
max_input_id=600,
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
result = strategy.wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for row in result:
|
||||||
|
for token_id in row["input_ids"]:
|
||||||
|
assert 500 <= token_id < 600
|
||||||
|
|
||||||
|
def test_seed_reproducibility(self):
|
||||||
|
kwargs = dict(
|
||||||
|
sequence_length=64, length=20, min_input_id=1, max_input_id=1000, seed=123
|
||||||
|
)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
|
||||||
|
result1 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
||||||
|
result2 = SyntheticDatasetStrategy(**kwargs).wrap_dataset(dummy)
|
||||||
|
|
||||||
|
for r1, r2 in zip(result1, result2, strict=True):
|
||||||
|
assert r1["input_ids"] == r2["input_ids"]
|
||||||
|
|
||||||
|
def test_different_seeds_differ(self):
|
||||||
|
common = dict(sequence_length=64, length=20, min_input_id=1, max_input_id=1000)
|
||||||
|
dummy = Dataset.from_dict({"text": [""]})
|
||||||
|
|
||||||
|
result1 = SyntheticDatasetStrategy(seed=1, **common).wrap_dataset(dummy)
|
||||||
|
result2 = SyntheticDatasetStrategy(seed=2, **common).wrap_dataset(dummy)
|
||||||
|
|
||||||
|
any_different = any(
|
||||||
|
r1["input_ids"] != r2["input_ids"]
|
||||||
|
for r1, r2 in zip(result1, result2, strict=True)
|
||||||
|
)
|
||||||
|
assert any_different
|
||||||
|
|
||||||
|
def test_load_function_with_ds_cfg(self):
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.vocab_size = 32000
|
||||||
|
cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
|
||||||
|
ds_cfg = {
|
||||||
|
"sequence_length": 256,
|
||||||
|
"length": 5,
|
||||||
|
"min_input_id": 10,
|
||||||
|
"max_input_id": 100,
|
||||||
|
"seed": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)
|
||||||
|
assert isinstance(strategy, SyntheticDatasetStrategy)
|
||||||
|
assert strategy.sequence_length == 256
|
||||||
|
assert strategy.length == 5
|
||||||
|
assert strategy.min_input_id == 10
|
||||||
|
assert strategy.max_input_id == 100
|
||||||
|
|
||||||
|
def test_load_defaults_from_cfg(self):
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.vocab_size = 32000
|
||||||
|
cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})
|
||||||
|
|
||||||
|
strategy = load(tokenizer, cfg, ds_cfg={})
|
||||||
|
assert strategy.sequence_length == 1024
|
||||||
|
assert strategy.max_input_id == 32000
|
||||||
|
assert strategy.length == 1000
|
||||||
|
|
||||||
|
def test_load_with_no_ds_cfg(self):
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.vocab_size = 50000
|
||||||
|
cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})
|
||||||
|
|
||||||
|
strategy = load(tokenizer, cfg)
|
||||||
|
assert strategy.sequence_length == 2048
|
||||||
|
assert strategy.max_input_id == 50000
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user