synthetic datasets for benchmarking and testing (#3518) [skip ci]
* synthetic datasets for benchmarking and testing * fix synthetic dataset parse from config and add tests * use type=_synthetic
This commit is contained in:
@@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||
from axolotl.utils.schemas.datasets import SFTDataset
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
@@ -1731,3 +1732,52 @@ class TestDataloaderValidation(BaseValidation):
|
||||
assert new_cfg.dataloader_num_workers == 8
|
||||
assert new_cfg.dataloader_pin_memory is True
|
||||
assert new_cfg.dataloader_prefetch_factor == 256
|
||||
|
||||
|
||||
class TestSyntheticDatasetValidation(BaseValidation):
|
||||
"""
|
||||
Tests for synthetic dataset config validation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_cfg(minimal_cfg, datasets):
|
||||
raw = dict(minimal_cfg)
|
||||
raw["datasets"] = datasets
|
||||
return DictDefault(raw)
|
||||
|
||||
def test_synthetic_dict_config_validates(self, minimal_cfg):
|
||||
"""Synthetic dataset passed as a raw dict should not raise."""
|
||||
cfg = self._make_cfg(
|
||||
minimal_cfg,
|
||||
[
|
||||
{
|
||||
"path": "synthetic",
|
||||
"type": "_synthetic",
|
||||
"length": 100,
|
||||
"sequence_length": 64,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||
|
||||
def test_synthetic_already_sft_does_not_crash(self, minimal_cfg):
|
||||
"""Synthetic dataset already parsed as SFTDataset should not raise AttributeError."""
|
||||
sft = SFTDataset(path="synthetic", type="_synthetic")
|
||||
cfg = self._make_cfg(minimal_cfg, [sft])
|
||||
|
||||
# Before the fix, this raised:
|
||||
# AttributeError: 'SFTDataset' object has no attribute 'get'
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.datasets[0]["path"] == "synthetic"
|
||||
|
||||
def test_non_synthetic_sft_validates(self, minimal_cfg):
|
||||
"""A regular SFT dataset should validate without being treated as synthetic."""
|
||||
cfg = self._make_cfg(
|
||||
minimal_cfg,
|
||||
[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.datasets[0]["path"] == "mhenrichsen/alpaca_2k_test"
|
||||
|
||||
Reference in New Issue
Block a user