fix: loading locally downloaded dataset (#2056) [skip ci]
This commit is contained in:
@@ -350,7 +350,15 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
try:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
|
||||
@@ -371,44 +371,79 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
def test_load_local_hub_with_revision(self):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir2:
|
||||
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
revision="d05c1cb",
|
||||
)
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
revision="d05c1cb",
|
||||
)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"ds_type": "parquet",
|
||||
"type": "alpaca",
|
||||
"data_files": [
|
||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||
],
|
||||
"revision": "d05c1cb",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"ds_type": "parquet",
|
||||
"type": "alpaca",
|
||||
"data_files": [
|
||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||
],
|
||||
"revision": "d05c1cb",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
def test_loading_local_dataset_folder(self):
|
||||
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": str(tmp_ds_path),
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user