Support loading datasets saved via save_to_disk (#1432)
* Support loading datasets saved via save_to_disk
* Add comprehensive unit tests
* Fix dataset tests due to new hash changes
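In practice, a dataset directory written by `Dataset.save_to_disk` can now be referenced directly from the `datasets:` config: when a local path is a directory and no `data_files` are listed, the loader falls back to `datasets.load_from_disk` instead of trying to read raw shards. A minimal sketch of the new branching, using a hypothetical helper name (the real logic lives inside `load_tokenized_prepared_datasets`):

from pathlib import Path

from datasets import load_dataset, load_from_disk


def load_local_dataset(path, ds_type=None, name=None, data_files=None):
    # Hypothetical helper, not part of the codebase, mirroring the branch added in this commit.
    local_path = Path(path)
    if local_path.is_dir():
        if data_files:
            # Explicit shards listed: load them with the declared format ("parquet", "json", ...).
            return load_dataset(
                ds_type, name=name, data_files=data_files, streaming=False, split=None
            )
        # No data_files given: assume the directory was produced by Dataset.save_to_disk.
        return load_from_disk(path)
    # Single files and hub datasets are handled by other branches.
    return None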
@@ -1,4 +1,5 @@
 """Module containing data utilities"""

+import functools
 import hashlib
 import logging
@@ -223,7 +224,7 @@ def load_tokenized_prepared_datasets(
             token=use_auth_token,
         )
         ds_from_hub = True
-    except (FileNotFoundError, ConnectionError, HFValidationError):
+    except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
         pass

     ds_from_cloud = False
@@ -290,14 +291,17 @@ def load_tokenized_prepared_datasets(
     local_path = Path(config_dataset.path)
     if local_path.exists():
         if local_path.is_dir():
-            # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
-            ds = load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
-                data_files=config_dataset.data_files,
-                streaming=False,
-                split=None,
-            )
+            if config_dataset.data_files:
+                ds_type = get_ds_type(config_dataset)
+                ds = load_dataset(
+                    ds_type,
+                    name=config_dataset.name,
+                    data_files=config_dataset.data_files,
+                    streaming=False,
+                    split=None,
+                )
+            else:
+                ds = load_from_disk(config_dataset.path)
         elif local_path.is_file():
             ds_type = get_ds_type(config_dataset)
tests/test_datasets.py (new file, 272 lines)
@@ -0,0 +1,272 @@
"""
Test dataset loading under various conditions.
"""

import shutil
import tempfile
import unittest
from pathlib import Path

from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault


class TestDatasetPreparation(unittest.TestCase):
    """Test a configured dataloader."""

    def setUp(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )
        # Alpaca dataset.
        self.dataset = Dataset.from_list(
            [
                {
                    "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                    "input": "He finnished his meal and left the resturant",
                    "output": "He finished his meal and left the restaurant.",
                }
            ]
        )

    def test_load_hub(self):
        """Core use case. Verify that processing data from the hub works"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_local_hub(self):
        """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
            )

            prepared_path = Path(tmp_dir) / "prepared"
            # Right now a local copy that doesn't fully conform to a dataset
            # must list data_files and ds_type otherwise the loader won't know
            # how to load it.
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
                            ],
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)

    def test_load_from_save_to_disk(self):
        """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
            self.dataset.save_to_disk(tmp_ds_name)

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_name),
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_dir_of_parquet(self):
        """Usual use case. Verify a directory of parquet files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.parquet"
            self.dataset.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "parquet",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_dir_of_json(self):
        """Standard use case. Verify a directory of json files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.json"
            self.dataset.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "json",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_single_parquet(self):
        """Standard use case. Verify a single parquet file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
            self.dataset.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_single_json(self):
        """Standard use case. Verify a single json file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
            self.dataset.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features


if __name__ == "__main__":
    unittest.main()
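Note: the suite can be run directly through the `__main__` guard above (`python tests/test_datasets.py`) or, presumably, with the project's usual runner, e.g. `pytest tests/test_datasets.py`. The hub-backed tests pull `mhenrichsen/alpaca_2k_test` and the `huggyllama/llama-7b` tokenizer, so a first run needs network access.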