use offline for precached stream dataset (#2453)
This commit is contained in:
@@ -4,8 +4,8 @@ Test dataset loading under various conditions.
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from constants import (
|
||||
@@ -15,7 +15,7 @@ from constants import (
|
||||
)
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
from utils import enable_hf_offline
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
@@ -23,15 +23,17 @@ from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestDatasetPreparation(unittest.TestCase):
|
||||
class TestDatasetPreparation:
|
||||
"""Test a configured dataloader."""
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
# Alpaca dataset.
|
||||
self.dataset = Dataset.from_list(
|
||||
@pytest.fixture
|
||||
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||
yield tokenizer_huggyllama
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_fixture(self):
|
||||
yield Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
|
||||
@@ -43,7 +45,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
def test_load_hub(self):
|
||||
def test_load_hub(self, tokenizer):
|
||||
"""Core use case. Verify that processing data from the hub works"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
@@ -60,9 +62,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -71,7 +71,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
def test_load_local_hub(self):
|
||||
def test_load_local_hub(self, tokenizer):
|
||||
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
@@ -106,9 +106,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -117,11 +115,11 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_save_to_disk(self):
|
||||
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
|
||||
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
||||
self.dataset.save_to_disk(str(tmp_ds_name))
|
||||
dataset_fixture.save_to_disk(str(tmp_ds_name))
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
@@ -137,9 +135,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -147,13 +143,13 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
assert "labels" in dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_dir_of_parquet(self):
|
||||
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||
tmp_ds_dir.mkdir()
|
||||
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
|
||||
self.dataset.to_parquet(tmp_ds_path)
|
||||
dataset_fixture.to_parquet(tmp_ds_path)
|
||||
|
||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
@@ -174,9 +170,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -184,13 +178,13 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
assert "labels" in dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_dir_of_json(self):
|
||||
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
|
||||
"""Standard use case. Verify a directory of json files can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||
tmp_ds_dir.mkdir()
|
||||
tmp_ds_path = tmp_ds_dir / "shard1.json"
|
||||
self.dataset.to_json(tmp_ds_path)
|
||||
dataset_fixture.to_json(tmp_ds_path)
|
||||
|
||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
@@ -211,9 +205,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -221,11 +213,11 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
assert "labels" in dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_single_parquet(self):
|
||||
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
|
||||
"""Standard use case. Verify a single parquet file can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
|
||||
self.dataset.to_parquet(tmp_ds_path)
|
||||
dataset_fixture.to_parquet(tmp_ds_path)
|
||||
|
||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
@@ -242,9 +234,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -252,11 +242,11 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
assert "labels" in dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_single_json(self):
|
||||
def test_load_from_single_json(self, tokenizer, dataset_fixture):
|
||||
"""Standard use case. Verify a single json file can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
|
||||
self.dataset.to_json(tmp_ds_path)
|
||||
dataset_fixture.to_json(tmp_ds_path)
|
||||
|
||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
@@ -273,9 +263,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -304,7 +292,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
def test_load_hub_with_revision(self):
|
||||
def test_load_hub_with_revision(self, tokenizer):
|
||||
"""Verify that processing data from the hub works with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
@@ -326,9 +314,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -336,7 +322,9 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
assert "labels" in dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_hub_with_revision_with_dpo(self):
|
||||
def test_load_hub_with_revision_with_dpo(
|
||||
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
):
|
||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||
|
||||
cfg = DictDefault(
|
||||
@@ -349,14 +337,23 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset:
|
||||
# Make the patched loader return the pre-cached dataset fixture on every call
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||
)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
def test_load_local_hub_with_revision(self):
|
||||
def test_load_local_hub_with_revision(self, tokenizer):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
@@ -388,9 +385,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -399,7 +394,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
@enable_hf_offline
|
||||
def test_loading_local_dataset_folder(self):
|
||||
def test_loading_local_dataset_folder(self, tokenizer):
|
||||
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -426,16 +421,10 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user