synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing
* fix synthetic dataset parsing from config and add tests
* use type=_synthetic
This commit is contained in:
@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||
from axolotl.utils.schemas.datasets import SFTDataset
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
|
||||
assert new_cfg.dataloader_num_workers == 8
|
||||
assert new_cfg.dataloader_pin_memory is True
|
||||
assert new_cfg.dataloader_prefetch_factor == 256
|
||||
|
||||
|
||||
class TestSyntheticDatasetValidation(BaseValidation):
    """
    Config-validation tests for datasets declared with type ``_synthetic``.

    Every test builds a config from ``minimal_cfg`` plus a ``datasets`` list,
    passes it through ``validate_config`` and checks that the first dataset's
    ``path`` is what was supplied.
    """

    @staticmethod
    def _make_cfg(minimal_cfg, datasets):
        """Return a ``DictDefault`` built from ``minimal_cfg`` with ``datasets`` set.

        A new top-level mapping is created, so assigning the ``datasets`` key
        does not touch the ``minimal_cfg`` object passed in.
        """
        return DictDefault({**minimal_cfg, "datasets": datasets})

    def test_synthetic_dict_config_validates(self, minimal_cfg):
        """A synthetic dataset given as a plain dict must pass validation."""
        synthetic_entry = {
            "path": "synthetic",
            "type": "_synthetic",
            "length": 100,
            "sequence_length": 64,
        }
        prepared = self._make_cfg(minimal_cfg, [synthetic_entry])

        validated = validate_config(prepared)
        first_dataset = validated.datasets[0]
        assert first_dataset["path"] == "synthetic"

    def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
        """A synthetic dataset supplied as an ``SFTDataset`` instance must pass validation."""
        prebuilt = SFTDataset(path="synthetic", type="_synthetic")
        prepared = self._make_cfg(minimal_cfg, [prebuilt])

        # Regression guard: this call used to fail with
        #   AttributeError: 'SFTDataset' object has no attribute 'get'
        validated = validate_config(prepared)
        first_dataset = validated.datasets[0]
        assert first_dataset["path"] == "synthetic"

    def test_non_synthetic_sft_validates(self, minimal_cfg):
        """An ordinary SFT dataset entry still validates and keeps its own path."""
        alpaca_entry = {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}
        prepared = self._make_cfg(minimal_cfg, [alpaca_entry])

        validated = validate_config(prepared)
        first_dataset = validated.datasets[0]
        assert first_dataset["path"] == "mhenrichsen/alpaca_2k_test"
|
||||
|
||||
125
tests/prompt_strategies/test_synthetic.py
Normal file
125
tests/prompt_strategies/test_synthetic.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Tests for the synthetic dataset generator."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_strategies._synthetic import SyntheticDatasetStrategy, load
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestSyntheticDatasetStrategy(unittest.TestCase):
    """Unit tests for ``SyntheticDatasetStrategy`` and the module-level ``load``."""

    @staticmethod
    def _placeholder_dataset():
        """Build the one-row text dataset that each test hands to ``wrap_dataset``."""
        return Dataset.from_dict({"text": [""]})

    @staticmethod
    def _mock_tokenizer(vocab_size):
        """Return a ``MagicMock`` tokenizer whose ``vocab_size`` is ``vocab_size``."""
        tok = MagicMock()
        tok.vocab_size = vocab_size
        return tok

    def test_generates_correct_shape(self):
        """Output has ``length`` rows; the first row's columns are ``sequence_length`` long."""
        n_rows, seq_len = 50, 128
        generator = SyntheticDatasetStrategy(
            sequence_length=seq_len,
            length=n_rows,
            min_input_id=1,
            max_input_id=1000,
            seed=42,
        )
        generated = generator.wrap_dataset(self._placeholder_dataset())

        assert len(generated) == n_rows
        first_row = generated[0]
        for column in ("input_ids", "attention_mask", "labels"):
            assert len(first_row[column]) == seq_len

    def test_attention_mask_all_ones(self):
        """Every attention-mask entry of every generated row equals 1."""
        generator = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
        generated = generator.wrap_dataset(self._placeholder_dataset())

        for sample in generated:
            assert not any(flag != 1 for flag in sample["attention_mask"])

    def test_labels_equal_input_ids(self):
        """Each generated row carries labels identical to its input ids."""
        generator = SyntheticDatasetStrategy(sequence_length=64, length=10, seed=0)
        generated = generator.wrap_dataset(self._placeholder_dataset())

        for sample in generated:
            ids, labels = sample["input_ids"], sample["labels"]
            assert ids == labels

    def test_input_id_range(self):
        """All token ids lie in the half-open range ``[min_input_id, max_input_id)``."""
        lower, upper = 500, 600
        generator = SyntheticDatasetStrategy(
            sequence_length=64,
            length=100,
            min_input_id=lower,
            max_input_id=upper,
            seed=42,
        )
        generated = generator.wrap_dataset(self._placeholder_dataset())

        for sample in generated:
            for token in sample["input_ids"]:
                assert lower <= token < upper

    def test_seed_reproducibility(self):
        """Two strategies built with the same arguments yield the same input ids."""
        params = {
            "sequence_length": 64,
            "length": 20,
            "min_input_id": 1,
            "max_input_id": 1000,
            "seed": 123,
        }
        source = self._placeholder_dataset()

        first_run = SyntheticDatasetStrategy(**params).wrap_dataset(source)
        second_run = SyntheticDatasetStrategy(**params).wrap_dataset(source)

        # strict=True also fails the test if the two runs differ in length.
        for left, right in zip(first_run, second_run, strict=True):
            assert left["input_ids"] == right["input_ids"]

    def test_different_seeds_differ(self):
        """Changing only the seed changes at least one row's input ids."""
        shared = {
            "sequence_length": 64,
            "length": 20,
            "min_input_id": 1,
            "max_input_id": 1000,
        }
        source = self._placeholder_dataset()

        with_seed_one = SyntheticDatasetStrategy(seed=1, **shared).wrap_dataset(source)
        with_seed_two = SyntheticDatasetStrategy(seed=2, **shared).wrap_dataset(source)

        row_pairs = zip(with_seed_one, with_seed_two, strict=True)
        found_difference = any(
            left["input_ids"] != right["input_ids"] for left, right in row_pairs
        )
        assert found_difference

    def test_load_function_with_ds_cfg(self):
        """Values given in ``ds_cfg`` end up on the strategy returned by ``load``."""
        global_cfg = DictDefault({"sequence_len": 512, "train_on_inputs": False})
        dataset_cfg = {
            "sequence_length": 256,
            "length": 5,
            "min_input_id": 10,
            "max_input_id": 100,
            "seed": 0,
        }

        built = load(self._mock_tokenizer(32000), global_cfg, ds_cfg=dataset_cfg)

        assert isinstance(built, SyntheticDatasetStrategy)
        expected_attrs = {
            "sequence_length": 256,
            "length": 5,
            "min_input_id": 10,
            "max_input_id": 100,
        }
        for attr_name, expected in expected_attrs.items():
            assert getattr(built, attr_name) == expected

    def test_load_defaults_from_cfg(self):
        """With an empty ``ds_cfg``, defaults come from ``cfg`` and the tokenizer.

        Expected: ``sequence_length`` equals ``cfg.sequence_len``,
        ``max_input_id`` equals the tokenizer's ``vocab_size`` and ``length``
        is 1000.
        """
        global_cfg = DictDefault({"sequence_len": 1024, "train_on_inputs": False})

        built = load(self._mock_tokenizer(32000), global_cfg, ds_cfg={})

        assert built.sequence_length == 1024
        assert built.max_input_id == 32000
        assert built.length == 1000

    def test_load_with_no_ds_cfg(self):
        """Omitting ``ds_cfg`` entirely still yields cfg/tokenizer-derived values."""
        global_cfg = DictDefault({"sequence_len": 2048, "train_on_inputs": False})

        built = load(self._mock_tokenizer(50000), global_cfg)

        assert built.sequence_length == 2048
        assert built.max_input_id == 50000
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|
||||
Reference in New Issue
Block a user