synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing * fix synthetic dataset parse from config and add tests * use type=_synthetic
This commit is contained in:
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
96
src/axolotl/prompt_strategies/_synthetic.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Synthetic dataset generator for benchmarking and testing.
|
||||
|
||||
Generates datasets with configurable sequence length, dataset size, and token ID ranges.
|
||||
Useful for benchmarking memory usage and speed by sequence length, and for validating
|
||||
weighted dataset mixes.
|
||||
|
||||
YAML configuration example:
|
||||
|
||||
datasets:
|
||||
- path: synthetic
|
||||
type: _synthetic
|
||||
length: 1000
|
||||
sequence_length: 2048
|
||||
min_input_id: 100
|
||||
max_input_id: 32000
|
||||
seed: 42
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SyntheticDatasetStrategy(DatasetWrappingStrategy):
    """Strategy that generates synthetic tokenized data, ignoring the source dataset.

    Each generated row is `sequence_length` random token IDs drawn uniformly
    from `[min_input_id, max_input_id)`, with a full attention mask and with
    labels equal to the input IDs (so loss is computed on every token).
    Intended for benchmarking and testing, not real training.
    """

    def __init__(
        self,
        sequence_length: int = 2048,
        length: int = 1000,
        min_input_id: int = 100,
        max_input_id: int = 32000,
        seed: Optional[int] = None,
    ):
        """
        Args:
            sequence_length: Number of tokens per generated row.
            length: Number of rows to generate.
            min_input_id: Inclusive lower bound for generated token IDs.
            max_input_id: Exclusive upper bound for generated token IDs.
            seed: Optional RNG seed for reproducible generation.

        Raises:
            ValueError: If a size is non-positive or the token ID range is empty.
        """
        # Fail fast with an actionable message instead of letting numpy raise
        # deep inside wrap_dataset() at generation time.
        if sequence_length <= 0:
            raise ValueError(f"sequence_length must be positive, got {sequence_length}")
        if length <= 0:
            raise ValueError(f"length must be positive, got {length}")
        if min_input_id >= max_input_id:
            raise ValueError(
                f"min_input_id ({min_input_id}) must be < max_input_id ({max_input_id})"
            )
        self.sequence_length = sequence_length
        self.length = length
        self.min_input_id = min_input_id
        self.max_input_id = max_input_id
        self.seed = seed

    def wrap_dataset(
        self,
        dataset,
        process_count: int | None = None,
        keep_in_memory: bool | None = False,
        **kwargs,
    ) -> Dataset:
        """Return a freshly generated synthetic dataset.

        The passed-in `dataset` and the other wrapping kwargs are ignored;
        the result is built entirely from this strategy's configuration.
        """
        LOG.info(
            f"Generating synthetic dataset: {self.length} samples, "
            f"sequence_length={self.sequence_length}, "
            f"input_id_range=[{self.min_input_id}, {self.max_input_id})"
        )

        rng = np.random.default_rng(self.seed)
        input_ids = rng.integers(
            low=self.min_input_id,
            high=self.max_input_id,  # exclusive upper bound
            size=(self.length, self.sequence_length),
        ).tolist()

        # Build independent inner lists; `[[1] * n] * m` would alias a single
        # list object across every row, so mutating one row mutates them all.
        attention_mask = [[1] * self.sequence_length for _ in range(self.length)]
        # labels == input_ids means we train on all tokens
        labels = [row[:] for row in input_ids]

        return Dataset.from_dict(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }
        )
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build a SyntheticDatasetStrategy from the dataset config.

    Args:
        tokenizer: Tokenizer whose `vocab_size` is the default (exclusive)
            upper bound for generated token IDs.
        cfg: Global config; `cfg.sequence_len` is the default row length.
        ds_cfg: Per-dataset config with optional keys `sequence_length`,
            `length`, `min_input_id`, `max_input_id`, and `seed`.

    Returns:
        A configured SyntheticDatasetStrategy.
    """
    ds_cfg = ds_cfg or {}

    # Treat an explicit None the same as an absent key: the SyntheticDataset
    # schema defaults sequence_length/max_input_id to None, and
    # dict.get(key, fallback) would return that None instead of the fallback
    # (an explicit high=None would then break rng.integers downstream).
    sequence_length = ds_cfg.get("sequence_length")
    if sequence_length is None:
        sequence_length = cfg.sequence_len
    max_input_id = ds_cfg.get("max_input_id")
    if max_input_id is None:
        max_input_id = tokenizer.vocab_size

    length = ds_cfg.get("length", 1000)
    min_input_id = ds_cfg.get("min_input_id", 100)
    seed = ds_cfg.get("seed", None)

    return SyntheticDatasetStrategy(
        sequence_length=sequence_length,
        length=length,
        min_input_id=min_input_id,
        max_input_id=max_input_id,
        seed=seed,
    )
|
||||
@@ -22,7 +22,12 @@ from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||
from axolotl.utils.schemas.datasets import (
|
||||
DPODataset,
|
||||
KTODataset,
|
||||
SFTDataset,
|
||||
SyntheticDataset,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -308,6 +313,14 @@ def validate_config(
|
||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||
elif (
|
||||
ds_cfg.get("type")
|
||||
if isinstance(ds_cfg, dict)
|
||||
else getattr(ds_cfg, "type", None)
|
||||
) == "_synthetic" and not isinstance(ds_cfg, SyntheticDataset):
|
||||
cfg["datasets"][idx] = SyntheticDataset(
|
||||
**(ds_cfg if isinstance(ds_cfg, dict) else dict(ds_cfg))
|
||||
)
|
||||
elif not isinstance(ds_cfg, SFTDataset):
|
||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||
|
||||
|
||||
@@ -376,10 +376,14 @@ def _load_and_process_single_dataset(
|
||||
streaming: bool = False,
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Load and process a single dataset based on the passed config."""
|
||||
# Load the dataset
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||
)
|
||||
# For synthetic datasets, create a minimal placeholder instead of loading from path
|
||||
if dataset_config.type == "_synthetic":
|
||||
dataset = Dataset.from_dict({"text": [""]})
|
||||
else:
|
||||
# Load the dataset
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=streaming
|
||||
)
|
||||
|
||||
# Parse dataset type
|
||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||
|
||||
@@ -22,6 +22,7 @@ from axolotl.utils.schemas.datasets import (
|
||||
PretrainingDataset,
|
||||
SFTDataset,
|
||||
StepwiseSupervisedDataset,
|
||||
SyntheticDataset,
|
||||
)
|
||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||
@@ -185,7 +186,13 @@ class AxolotlInputConfig(
|
||||
|
||||
datasets: (
|
||||
Annotated[
|
||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||
list[
|
||||
SFTDataset
|
||||
| DPODataset
|
||||
| KTODataset
|
||||
| StepwiseSupervisedDataset
|
||||
| SyntheticDataset
|
||||
],
|
||||
MinLen(1),
|
||||
]
|
||||
| None
|
||||
@@ -198,7 +205,13 @@ class AxolotlInputConfig(
|
||||
|
||||
test_datasets: (
|
||||
Annotated[
|
||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||
list[
|
||||
SFTDataset
|
||||
| DPODataset
|
||||
| KTODataset
|
||||
| StepwiseSupervisedDataset
|
||||
| SyntheticDataset
|
||||
],
|
||||
MinLen(1),
|
||||
]
|
||||
| None
|
||||
|
||||
@@ -296,4 +296,42 @@ class KTODataset(BaseModel):
|
||||
revision: str | None = None
|
||||
|
||||
|
||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||
class SyntheticDataset(BaseModel):
    """Synthetic dataset configuration for benchmarking and testing.

    Generates datasets with configurable sequence length, dataset size, and token ID
    ranges. Useful for benchmarking memory usage and speed by sequence length, and for
    validating weighted dataset mixes.
    """

    # Fixed discriminators: `path: synthetic` + `type: _synthetic` select the
    # synthetic prompt strategy instead of loading a real dataset from `path`.
    path: Literal["synthetic"] = "synthetic"
    type: Literal["_synthetic"] = "_synthetic"
    length: int = Field(
        default=1000,
        json_schema_extra={"description": "Number of rows to generate"},
    )
    # None means "resolve later" — per the description, from the config's
    # sequence_len rather than a fixed value here.
    sequence_length: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Sequence length per row (defaults to sequence_len from config)"
        },
    )
    min_input_id: int = Field(
        default=100,
        json_schema_extra={"description": "Minimum token ID for generation"},
    )
    # None means "resolve later" — per the description, from the tokenizer's
    # vocab_size rather than a fixed value here.
    max_input_id: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum token ID for generation (defaults to tokenizer vocab_size)"
        },
    )
    seed: int | None = Field(
        default=None,
        json_schema_extra={"description": "Random seed for reproducibility"},
    )
|
||||
|
||||
|
||||
DatasetConfig = (
|
||||
SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user