From c49682132b44613d35e224ecbcb554db07dda85a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 28 Mar 2025 23:39:09 -0400
Subject: [PATCH] use offline for precached stream dataset (#2453)

---
 .github/workflows/tests.yml              |   3 +
 tests/conftest.py                        |  44 +++++--
 tests/prompt_strategies/conftest.py      |   9 +-
 .../test_chat_templates_advanced.py      |   2 +
 tests/test_datasets.py                   | 113 ++++++++----------
 tests/test_exact_deduplication.py        |  75 ++++++++----
 tests/test_packed_pretraining.py         |  55 +++++----
 tests/utils/__init__.py                  |   2 +-
 8 files changed, 179 insertions(+), 124 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4ef8f54f7..632731a2d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -171,6 +171,9 @@ jobs:
         run: |
           axolotl --help
 
+      - name: Show HF cache
+        run: huggingface-cli scan-cache
+
       - name: Run tests
         run: |
           pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
diff --git a/tests/conftest.py b/tests/conftest.py
index 7a42ce428..8cf083290 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,8 +11,10 @@ import time
 
 import pytest
 import requests
+from datasets import load_dataset
 from huggingface_hub import snapshot_download
-from utils import disable_hf_offline
+from transformers import AutoTokenizer
+from utils import disable_hf_offline, enable_hf_offline
 
 
 def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -46,7 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_smollm2_135m_model():
     # download the model
     snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@@ -59,28 +60,24 @@ def download_llama_68m_random_model():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_qwen_2_5_half_billion_model():
     # download the model
     snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_tatsu_lab_alpaca_dataset():
     # download the dataset
     snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_mhenrichsen_alpaca_2k_dataset():
     # download the dataset
     snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_mhenrichsen_alpaca_2k_w_revision_dataset():
     # download the dataset
     snapshot_download_w_retry(
@@ -89,7 +86,6 @@ def download_mhenrichsen_alpaca_2k_w_revision_dataset():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_mlabonne_finetome_100k_dataset():
     # download the dataset
     snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
@@ -124,6 +120,24 @@ def download_fozzie_alpaca_dpo_dataset():
     )
 
 
+@pytest.fixture(scope="session")
+@disable_hf_offline
+def dataset_fozzie_alpaca_dpo_dataset(
+    download_fozzie_alpaca_dpo_dataset,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
+
+
+@pytest.fixture(scope="session")
+@disable_hf_offline
+def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
+    download_fozzie_alpaca_dpo_dataset,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    return load_dataset(
+        "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
+    )
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
     # download the dataset
@@ -152,7 +166,6 @@ def download_deepseek_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_huggyllama_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(
@@ -163,7 +176,6 @@ def download_huggyllama_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_llama_1b_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(
@@ -174,7 +186,6 @@ def download_llama_1b_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_llama3_8b_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(
@@ -183,7 +194,6 @@ def download_llama3_8b_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_llama3_8b_instruct_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(
@@ -194,7 +204,6 @@ def download_llama3_8b_instruct_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_phi_35_mini_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(
@@ -263,6 +272,17 @@ def download_llama2_model_fixture():
     )
 
 
+@pytest.fixture(scope="session", autouse=True)
+@enable_hf_offline
+def tokenizer_huggyllama(
+    download_huggyllama_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+    tokenizer.pad_token = "</s>"
+
+    return tokenizer
+
+
 @pytest.fixture
 def temp_dir():
     # Create a temporary directory
diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py
index d4d0c12f8..44914e617 100644
--- a/tests/prompt_strategies/conftest.py
+++ b/tests/prompt_strategies/conftest.py
@@ -109,7 +109,9 @@ def fixture_toolcalling_dataset():
 
 @pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
 @enable_hf_offline
-def fixture_llama3_tokenizer():
+def fixture_llama3_tokenizer(
+    download_llama3_8b_instruct_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
     tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
 
     return tokenizer
@@ -123,7 +125,10 @@ def fixture_smollm2_tokenizer():
 
 
 @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
-def fixture_mistralv03_tokenizer():
+@enable_hf_offline
+def fixture_mistralv03_tokenizer(
+    download_mlx_mistral_7b_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
     tokenizer = AutoTokenizer.from_pretrained(
         "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
     )
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
index 69031bd65..f316e6ec3 100644
--- a/tests/prompt_strategies/test_chat_templates_advanced.py
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -9,6 +9,7 @@ import pytest
 from datasets import Dataset
 from tokenizers import AddedToken
 from transformers import PreTrainedTokenizer
+from utils import enable_hf_offline
 
 from axolotl.prompt_strategies.chat_template import (
     ChatTemplatePrompter,
@@ -101,6 +102,7 @@ class TestChatTemplateConfigurations:
                 return True
         return False
 
+    @enable_hf_offline
     def test_train_on_inputs_true(
         self,
         tokenizer,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 4a64074d7..71d285497 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -4,8 +4,8 @@ Test dataset loading under various conditions.
 
 import shutil
 import tempfile
-import unittest
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 from constants import (
@@ -15,7 +15,7 @@ from constants import (
 )
 from datasets import Dataset
 from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizer
 from utils import enable_hf_offline
 
 from axolotl.utils.data import load_tokenized_prepared_datasets
@@ -23,15 +23,17 @@ from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 
 
-class TestDatasetPreparation(unittest.TestCase):
+class TestDatasetPreparation:
     """Test a configured dataloader."""
 
-    @enable_hf_offline
-    def setUp(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
-        # Alpaca dataset.
-        self.dataset = Dataset.from_list(
+    @pytest.fixture
+    def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
+        tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
+        yield tokenizer_huggyllama
+
+    @pytest.fixture
+    def dataset_fixture(self):
+        yield Dataset.from_list(
             [
                 {
                     "instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -43,7 +45,7 @@ class TestDatasetPreparation(unittest.TestCase):
 
     @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     @enable_hf_offline
-    def test_load_hub(self):
+    def test_load_hub(self, tokenizer):
         """Core use case. Verify that processing data from the hub works"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
@@ -60,9 +62,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
@@ -71,7 +71,7 @@ class TestDatasetPreparation(unittest.TestCase):
 
     @enable_hf_offline
     @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub(self):
+    def test_load_local_hub(self, tokenizer):
         """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -106,9 +106,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
@@ -117,11 +115,11 @@ class TestDatasetPreparation(unittest.TestCase):
             shutil.rmtree(tmp_ds_path)
 
     @enable_hf_offline
-    def test_load_from_save_to_disk(self):
+    def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
         """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
-            self.dataset.save_to_disk(str(tmp_ds_name))
+            dataset_fixture.save_to_disk(str(tmp_ds_name))
 
             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -137,9 +135,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
@@ -147,13 +143,13 @@ class TestDatasetPreparation(unittest.TestCase):
             assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_dir_of_parquet(self):
+    def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
        """Usual use case. Verify a directory of parquet files can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
             tmp_ds_dir.mkdir()
             tmp_ds_path = tmp_ds_dir / "shard1.parquet"
-            self.dataset.to_parquet(tmp_ds_path)
+            dataset_fixture.to_parquet(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -174,9 +170,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
@@ -184,13 +178,13 @@ class TestDatasetPreparation(unittest.TestCase):
             assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_dir_of_json(self):
+    def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
         """Standard use case. Verify a directory of json files can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
             tmp_ds_dir.mkdir()
             tmp_ds_path = tmp_ds_dir / "shard1.json"
-            self.dataset.to_json(tmp_ds_path)
+            dataset_fixture.to_json(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -211,9 +205,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
@@ -221,11 +213,11 @@ class TestDatasetPreparation(unittest.TestCase):
             assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_single_parquet(self):
+    def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
         """Standard use case. Verify a single parquet file can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
-            self.dataset.to_parquet(tmp_ds_path)
+            dataset_fixture.to_parquet(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -242,9 +234,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
@@ -252,11 +242,11 @@ class TestDatasetPreparation(unittest.TestCase):
             assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_single_json(self):
+    def test_load_from_single_json(self, tokenizer, dataset_fixture):
         """Standard use case. Verify a single json file can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
-            self.dataset.to_json(tmp_ds_path)
+            dataset_fixture.to_json(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -273,9 +263,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
@@ -304,7 +292,7 @@ class TestDatasetPreparation(unittest.TestCase):
 
     @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     @enable_hf_offline
-    def test_load_hub_with_revision(self):
+    def test_load_hub_with_revision(self, tokenizer):
         """Verify that processing data from the hub works with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
@@ -326,9 +314,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
@@ -336,7 +322,9 @@ class TestDatasetPreparation(unittest.TestCase):
         assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_hub_with_revision_with_dpo(self):
+    def test_load_hub_with_revision_with_dpo(
+        self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+    ):
         """Verify that processing dpo data from the hub works with a specific revision"""
 
         cfg = DictDefault(
@@ -349,14 +337,23 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        train_dataset, _ = load_prepare_preference_datasets(cfg)
+        # pylint: disable=duplicate-code
+        with patch(
+            "axolotl.utils.data.shared.load_dataset_w_config"
+        ) as mock_load_dataset:
+            # return the precached dataset fixture instead of downloading from the hub
+            mock_load_dataset.return_value = (
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+            )
 
-        assert len(train_dataset) == 1800
-        assert "conversation" in train_dataset.features
+            train_dataset, _ = load_prepare_preference_datasets(cfg)
+
+            assert len(train_dataset) == 1800
+            assert "conversation" in train_dataset.features
 
     @enable_hf_offline
     @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub_with_revision(self):
+    def test_load_local_hub_with_revision(self, tokenizer):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -388,9 +385,7 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
@@ -399,7 +394,7 @@ class TestDatasetPreparation(unittest.TestCase):
             shutil.rmtree(tmp_ds_path)
 
     @enable_hf_offline
-    def test_loading_local_dataset_folder(self):
+    def test_loading_local_dataset_folder(self, tokenizer):
         """Verify that a dataset downloaded to a local folder can be loaded"""
 
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -426,16 +421,10 @@ class TestDatasetPreparation(unittest.TestCase):
                 }
             )
-            dataset, _ = load_tokenized_prepared_datasets(
-                self.tokenizer, cfg, prepared_path
-            )
+            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
             shutil.rmtree(tmp_ds_path)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
index 865ff030c..9549860f7 100644
--- a/tests/test_exact_deduplication.py
+++ b/tests/test_exact_deduplication.py
@@ -9,9 +9,8 @@ import unittest
 from unittest.mock import patch
 
 import pytest
-from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
+from constants import ALPACA_MESSAGES_CONFIG_REVISION
 from datasets import Dataset
-from transformers import AutoTokenizer
 from utils import enable_hf_offline
 
 from axolotl.utils.config import normalize_config
@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
         verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
 
 
-class TestDeduplicateRLDataset(unittest.TestCase):
+class TestDeduplicateRLDataset:
     """Test a configured dataloader with deduplication."""
 
-    def setUp(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
-        self.cfg = DictDefault(
+    @pytest.fixture
+    def cfg(self):
+        fixture = DictDefault(
             {
                 "tokenizer_config": "huggyllama/llama-7b",
                 "sequence_len": 1024,
@@ -235,28 +233,59 @@ def setUp(self) -> None:
                 ],
             }
         )
+        yield fixture
 
-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     @enable_hf_offline
-    def test_load_with_deduplication(self):
+    def test_load_with_deduplication(
+        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+    ):
         """Verify that loading with deduplication removes duplicates."""
-        # Load the dataset using the deduplication setting
-        train_dataset, _ = load_prepare_preference_datasets(self.cfg)
+        # pylint: disable=duplicate-code
+        with (
+            patch(
+                "axolotl.utils.data.shared.load_dataset_w_config"
+            ) as mock_load_dataset,
+            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+        ):
+            # Set up the mock to return different values on successive calls
+            mock_load_dataset.side_effect = [
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+            ]
+            mock_load_tokenizer.return_value = tokenizer_huggyllama
 
-        # Verify that the dataset has been deduplicated
-        assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
+            train_dataset, _ = load_prepare_preference_datasets(cfg)
 
-    def test_load_without_deduplication(self):
-        """Verify that loading without deduplication retains duplicates."""
-        self.cfg.dataset_exact_deduplication = False
-        # Load the dataset without deduplication
-        train_dataset, _ = load_prepare_preference_datasets(self.cfg)
+            # Verify that the dataset has been deduplicated
+            assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
 
-        # Verify that the dataset retains duplicates
-        assert (
-            len(train_dataset) == 1800 * 2
-        ), "Dataset deduplication occurred when it should not have"
+    @enable_hf_offline
+    def test_load_without_deduplication(
+        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+    ):
+        # pylint: disable=duplicate-code
+        with (
+            patch(
+                "axolotl.utils.data.shared.load_dataset_w_config"
+            ) as mock_load_dataset,
+            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+        ):
+            # Set up the mock to return different values on successive calls
+            mock_load_dataset.side_effect = [
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+            ]
+            mock_load_tokenizer.return_value = tokenizer_huggyllama
+
+            cfg.dataset_exact_deduplication = False
+            # Load the dataset without deduplication
+            train_dataset, _ = load_prepare_preference_datasets(cfg)
+
+            # Verify that the dataset retains duplicates
+            assert (
+                len(train_dataset) == 1800 * 2
+            ), "Dataset deduplication occurred when it should not have"
 
 
 class TestDeduplicateNonRL(unittest.TestCase):
@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
 
     @enable_hf_offline
     def setUp(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
         self.cfg_1 = DictDefault(
             {
                 "base_model": "huggyllama/llama-7b",
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index bd8d81dcc..f783af9cc 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -1,38 +1,50 @@
 """Module for testing streaming dataset sequence packing"""
 
 import functools
-import unittest
+import random
+import string
 
 import pytest
 import torch
-from datasets import load_dataset
+from datasets import IterableDataset
 from torch.utils.data import DataLoader
-from transformers import AutoTokenizer
-from utils import disable_hf_offline, enable_hf_offline
 
 from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 
 
-class TestPretrainingPacking(unittest.TestCase):
+class TestPretrainingPacking:
     """
     Test class for packing streaming dataset sequences
     """
 
-    @enable_hf_offline
-    def setUp(self) -> None:
-        # pylint: disable=duplicate-code
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.pad_token = "</s>"
+    @pytest.fixture
+    def random_text(self):
+        # seed the RNG so the generated rows are reproducible
+        random.seed(0)
+
+        # generate 20 short rows (lone "words" of 2 to 10 characters) plus
+        # 20 long rows of 400 to 1200 space-separated characters each
+        data = [
+            "".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
+            for _ in range(20)
+        ] + [
+            " ".join(
+                random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
+            )
+            for _ in range(20)
+        ]
+
+        # Create an IterableDataset
+        def generator():
+            for text in data:
+                yield {"text": text}
+
+        return IterableDataset.from_generator(generator)
 
     @pytest.mark.flaky(retries=1, delay=5)
-    @disable_hf_offline
-    def test_packing_stream_dataset(self):
-        # pylint: disable=duplicate-code
-        dataset = load_dataset(
-            "winglian/tiny-shakespeare",
-            streaming=True,
-        )["train"]
+    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
+        dataset = random_text
 
         cfg = DictDefault(
             {
@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
+        # pylint: disable=duplicate-code
         original_bsz = cfg.micro_batch_size
         train_dataset = wrap_pretraining_dataset(
             dataset,
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
                 #     [1, original_bsz * cfg.sequence_len]
                 # )
                 idx += 1
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index bc4920671..0ce878577 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -19,7 +19,7 @@ def reload_modules(hf_hub_offline):
     importlib.reload(huggingface_hub.constants)
     huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
     importlib.reload(datasets.config)
-    datasets.config.HF_HUB_OFFLINE = hf_hub_offline
+    setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
 
     reset_sessions()