use offline for precached stream dataset (#2453)

.github/workflows/tests.yml | 3 +++
@@ -171,6 +171,9 @@ jobs:
         run: |
           axolotl --help
 
+      - name: Show HF cache
+        run: huggingface-cli scan-cache
+
       - name: Run tests
         run: |
           pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
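
The new `Show HF cache` step prints what the runner has precached before the test suite runs with offline mode in play. For reference, the same inspection is available programmatically via `huggingface_hub.scan_cache_dir()`; a minimal sketch:

```python
from huggingface_hub import scan_cache_dir

# Enumerate everything already present in the local HF cache,
# mirroring what `huggingface-cli scan-cache` prints in the workflow.
cache_info = scan_cache_dir()
for repo in cache_info.repos:
    print(repo.repo_id, repo.repo_type, repo.size_on_disk_str)
```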
@@ -11,8 +11,10 @@ import time
 
 import pytest
 import requests
 from datasets import load_dataset
 from huggingface_hub import snapshot_download
-from utils import disable_hf_offline
+from transformers import AutoTokenizer
+from utils import disable_hf_offline, enable_hf_offline
 
 
 def retry_on_request_exceptions(max_retries=3, delay=1):

@@ -46,7 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_smollm2_135m_model():
     # download the model
     snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")

@@ -59,28 +60,24 @@ def download_llama_68m_random_model():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_qwen_2_5_half_billion_model():
     # download the model
     snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_tatsu_lab_alpaca_dataset():
     # download the dataset
     snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_mhenrichsen_alpaca_2k_dataset():
     # download the dataset
     snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_mhenrichsen_alpaca_2k_w_revision_dataset():
     # download the dataset
     snapshot_download_w_retry(

@@ -89,7 +86,6 @@ def download_mhenrichsen_alpaca_2k_w_revision_dataset():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_mlabonne_finetome_100k_dataset():
     # download the dataset
     snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")

@@ -124,6 +120,24 @@ def download_fozzie_alpaca_dpo_dataset():
     )
 
 
+@pytest.fixture(scope="session")
+@disable_hf_offline
+def dataset_fozzie_alpaca_dpo_dataset(
+    download_fozzie_alpaca_dpo_dataset,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
+
+
+@pytest.fixture(scope="session")
+@disable_hf_offline
+def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
+    download_fozzie_alpaca_dpo_dataset,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    return load_dataset(
+        "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
+    )
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
     # download the dataset

@@ -152,7 +166,6 @@ def download_deepseek_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_huggyllama_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(

@@ -163,7 +176,6 @@ def download_huggyllama_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_llama_1b_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(

@@ -174,7 +186,6 @@ def download_llama_1b_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_llama3_8b_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(

@@ -183,7 +194,6 @@ def download_llama3_8b_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_llama3_8b_instruct_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(

@@ -194,7 +204,6 @@ def download_llama3_8b_instruct_model_fixture():
 
 
 @pytest.fixture(scope="session", autouse=True)
-@disable_hf_offline
 def download_phi_35_mini_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(

@@ -263,6 +272,17 @@ def download_llama2_model_fixture():
     )
 
 
+@pytest.fixture(scope="session", autouse=True)
+@enable_hf_offline
+def tokenizer_huggyllama(
+    download_huggyllama_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+    tokenizer.pad_token = "</s>"
+
+    return tokenizer
+
+
 @pytest.fixture
 def temp_dir():
     # Create a temporary directory
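
`retry_on_request_exceptions` and `snapshot_download_w_retry` are referenced but not defined in these hunks. A minimal sketch of the retry pattern their signatures imply, assuming retries on `requests` exceptions with a fixed delay:

```python
import functools
import time

import requests
from huggingface_hub import snapshot_download


def retry_on_request_exceptions(max_retries=3, delay=1):
    # Retry the wrapped call on transient network errors (e.g. HF rate limits).
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except requests.exceptions.RequestException:
                    if attempt == max_retries - 1:
                        raise
                    time.sleep(delay)
        return wrapper
    return decorator


@retry_on_request_exceptions(max_retries=3, delay=1)
def snapshot_download_w_retry(*args, **kwargs):
    return snapshot_download(*args, **kwargs)
```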
@@ -109,7 +109,9 @@ def fixture_toolcalling_dataset():
 
 @pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
-def fixture_llama3_tokenizer():
+@enable_hf_offline
+def fixture_llama3_tokenizer(
+    download_llama3_8b_instruct_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
     tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
 
     return tokenizer

@@ -123,7 +125,10 @@ def fixture_smollm2_tokenizer():
 
 
 @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
-def fixture_mistralv03_tokenizer():
+@enable_hf_offline
+def fixture_mistralv03_tokenizer(
+    download_mlx_mistral_7b_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
     tokenizer = AutoTokenizer.from_pretrained(
         "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
     )

@@ -9,6 +9,7 @@ import pytest
 from datasets import Dataset
 from tokenizers import AddedToken
 from transformers import PreTrainedTokenizer
+from utils import enable_hf_offline
 
 from axolotl.prompt_strategies.chat_template import (
     ChatTemplatePrompter,

@@ -101,6 +102,7 @@ class TestChatTemplateConfigurations:
             return True
         return False
 
+    @enable_hf_offline
     def test_train_on_inputs_true(
         self,
         tokenizer,
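
The rewritten tokenizer fixtures all follow one pattern: the session-scoped download fixture is declared as a parameter (unused in the body) so pytest must run it first, and `@enable_hf_offline` then keeps `from_pretrained` resolving against the already-populated cache. A stripped-down, self-contained sketch of that ordering trick, with illustrative names and a plain list standing in for the HF cache:

```python
import pytest

CACHE = []


@pytest.fixture(scope="session")
def download_fixture():
    # Stand-in for snapshot_download_w_retry(...); runs once per session.
    CACHE.append("model-files")


@pytest.fixture(scope="session")
def tokenizer(download_fixture):  # pylint: disable=unused-argument
    # Declaring download_fixture as a parameter guarantees the "download"
    # has happened, so this body can rely on the cache being populated.
    return CACHE[-1]


def test_uses_cache(tokenizer):
    assert tokenizer == "model-files"
```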
@@ -4,8 +4,8 @@ Test dataset loading under various conditions.
 
 import shutil
 import tempfile
-import unittest
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 from constants import (

@@ -15,7 +15,7 @@ from constants import (
 )
 from datasets import Dataset
-from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizer
+from utils import enable_hf_offline
 
 from axolotl.utils.data import load_tokenized_prepared_datasets

@@ -23,15 +23,17 @@ from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 
 
-class TestDatasetPreparation(unittest.TestCase):
+class TestDatasetPreparation:
     """Test a configured dataloader."""
 
-    @enable_hf_offline
-    def setUp(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
-        # Alpaca dataset.
-        self.dataset = Dataset.from_list(
+    @pytest.fixture
+    def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
+        tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
+        yield tokenizer_huggyllama
+
+    @pytest.fixture
+    def dataset_fixture(self):
+        yield Dataset.from_list(
             [
                 {
                     "instruction": "Evaluate this sentence for spelling and grammar mistakes",

@@ -43,7 +45,7 @@ class TestDatasetPreparation(unittest.TestCase):
 
     @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     @enable_hf_offline
-    def test_load_hub(self):
+    def test_load_hub(self, tokenizer):
         """Core use case. Verify that processing data from the hub works"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"

@@ -60,9 +62,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 2000
         assert "input_ids" in dataset.features

@@ -71,7 +71,7 @@ class TestDatasetPreparation(unittest.TestCase):
 
     @enable_hf_offline
     @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub(self):
+    def test_load_local_hub(self, tokenizer):
         """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"

@@ -106,9 +106,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 2000
         assert "input_ids" in dataset.features

@@ -117,11 +115,11 @@ class TestDatasetPreparation(unittest.TestCase):
         shutil.rmtree(tmp_ds_path)
 
     @enable_hf_offline
-    def test_load_from_save_to_disk(self):
+    def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
         """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
-            self.dataset.save_to_disk(str(tmp_ds_name))
+            dataset_fixture.save_to_disk(str(tmp_ds_name))
 
             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(

@@ -137,9 +135,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 1
         assert "input_ids" in dataset.features

@@ -147,13 +143,13 @@ class TestDatasetPreparation(unittest.TestCase):
         assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_dir_of_parquet(self):
+    def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
         """Usual use case. Verify a directory of parquet files can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
             tmp_ds_dir.mkdir()
             tmp_ds_path = tmp_ds_dir / "shard1.parquet"
-            self.dataset.to_parquet(tmp_ds_path)
+            dataset_fixture.to_parquet(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(

@@ -174,9 +170,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 1
         assert "input_ids" in dataset.features

@@ -184,13 +178,13 @@ class TestDatasetPreparation(unittest.TestCase):
         assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_dir_of_json(self):
+    def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
         """Standard use case. Verify a directory of json files can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
             tmp_ds_dir.mkdir()
             tmp_ds_path = tmp_ds_dir / "shard1.json"
-            self.dataset.to_json(tmp_ds_path)
+            dataset_fixture.to_json(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(

@@ -211,9 +205,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 1
         assert "input_ids" in dataset.features

@@ -221,11 +213,11 @@ class TestDatasetPreparation(unittest.TestCase):
         assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_single_parquet(self):
+    def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
         """Standard use case. Verify a single parquet file can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
-            self.dataset.to_parquet(tmp_ds_path)
+            dataset_fixture.to_parquet(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(

@@ -242,9 +234,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 1
         assert "input_ids" in dataset.features

@@ -252,11 +242,11 @@ class TestDatasetPreparation(unittest.TestCase):
         assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_from_single_json(self):
+    def test_load_from_single_json(self, tokenizer, dataset_fixture):
         """Standard use case. Verify a single json file can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
-            self.dataset.to_json(tmp_ds_path)
+            dataset_fixture.to_json(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(

@@ -273,9 +263,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 1
         assert "input_ids" in dataset.features

@@ -304,7 +292,7 @@ class TestDatasetPreparation(unittest.TestCase):
 
     @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     @enable_hf_offline
-    def test_load_hub_with_revision(self):
+    def test_load_hub_with_revision(self, tokenizer):
         """Verify that processing data from the hub works with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"

@@ -326,9 +314,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 2000
         assert "input_ids" in dataset.features

@@ -336,7 +322,9 @@ class TestDatasetPreparation(unittest.TestCase):
         assert "labels" in dataset.features
 
     @enable_hf_offline
-    def test_load_hub_with_revision_with_dpo(self):
+    def test_load_hub_with_revision_with_dpo(
+        self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+    ):
         """Verify that processing dpo data from the hub works with a specific revision"""
 
         cfg = DictDefault(

@@ -349,14 +337,23 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        train_dataset, _ = load_prepare_preference_datasets(cfg)
-
-        assert len(train_dataset) == 1800
-        assert "conversation" in train_dataset.features
+        # pylint: disable=duplicate-code
+        with patch(
+            "axolotl.utils.data.shared.load_dataset_w_config"
+        ) as mock_load_dataset:
+            # Set up the mock to return different values on successive calls
+            mock_load_dataset.return_value = (
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+            )
+
+            train_dataset, _ = load_prepare_preference_datasets(cfg)
+
+            assert len(train_dataset) == 1800
+            assert "conversation" in train_dataset.features
 
     @enable_hf_offline
     @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub_with_revision(self):
+    def test_load_local_hub_with_revision(self, tokenizer):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"

@@ -388,9 +385,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 2000
         assert "input_ids" in dataset.features

@@ -399,7 +394,7 @@ class TestDatasetPreparation(unittest.TestCase):
         shutil.rmtree(tmp_ds_path)
 
     @enable_hf_offline
-    def test_loading_local_dataset_folder(self):
+    def test_loading_local_dataset_folder(self, tokenizer):
         """Verify that a dataset downloaded to a local folder can be loaded"""
 
         with tempfile.TemporaryDirectory() as tmp_dir:

@@ -426,16 +421,10 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        dataset, _ = load_tokenized_prepared_datasets(
-            self.tokenizer, cfg, prepared_path
-        )
+        dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
 
         assert len(dataset) == 2000
         assert "input_ids" in dataset.features
         assert "attention_mask" in dataset.features
         assert "labels" in dataset.features
         shutil.rmtree(tmp_ds_path)
-
-
-if __name__ == "__main__":
-    unittest.main()
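
Every local-file test above performs the same round trip: write the in-memory `Dataset` out in some on-disk format, then load and tokenize it through axolotl. The `datasets`-level half of that round trip, independent of axolotl, looks like this (paths and contents are illustrative):

```python
import tempfile
from pathlib import Path

from datasets import Dataset, load_dataset

ds = Dataset.from_list([{"text": "hello"}, {"text": "world"}])

with tempfile.TemporaryDirectory() as tmp_dir:
    # Same shape the tests exercise: materialize a parquet shard, reload it.
    parquet_path = Path(tmp_dir) / "shard1.parquet"
    ds.to_parquet(parquet_path)
    reloaded = load_dataset("parquet", data_files=str(parquet_path), split="train")
    assert reloaded[0]["text"] == "hello"
```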
@@ -9,9 +9,8 @@ import unittest
 from unittest.mock import patch
 
 import pytest
-from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
+from constants import ALPACA_MESSAGES_CONFIG_REVISION
 from datasets import Dataset
-from transformers import AutoTokenizer
+from utils import enable_hf_offline
 
 from axolotl.utils.config import normalize_config

@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
         verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
 
 
-class TestDeduplicateRLDataset(unittest.TestCase):
+class TestDeduplicateRLDataset:
     """Test a configured dataloader with deduplication."""
 
-    def setUp(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
-        self.cfg = DictDefault(
+    @pytest.fixture
+    def cfg(self):
+        fixture = DictDefault(
             {
                 "tokenizer_config": "huggyllama/llama-7b",
                 "sequence_len": 1024,

@@ -235,28 +233,59 @@ class TestDeduplicateRLDataset(unittest.TestCase):
                 ],
             }
         )
+        yield fixture
 
     @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     @enable_hf_offline
-    def test_load_with_deduplication(self):
+    def test_load_with_deduplication(
+        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+    ):
         """Verify that loading with deduplication removes duplicates."""
 
-        # Load the dataset using the deduplication setting
-        train_dataset, _ = load_prepare_preference_datasets(self.cfg)
-
-        # Verify that the dataset has been deduplicated
-        assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
-
-    def test_load_without_deduplication(self):
-        """Verify that loading without deduplication retains duplicates."""
-        self.cfg.dataset_exact_deduplication = False
-        # Load the dataset without deduplication
-        train_dataset, _ = load_prepare_preference_datasets(self.cfg)
-
-        # Verify that the dataset retains duplicates
-        assert (
-            len(train_dataset) == 1800 * 2
-        ), "Dataset deduplication occurred when it should not have"
+        # pylint: disable=duplicate-code
+        with (
+            patch(
+                "axolotl.utils.data.shared.load_dataset_w_config"
+            ) as mock_load_dataset,
+            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+        ):
+            # Set up the mock to return different values on successive calls
+            mock_load_dataset.side_effect = [
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+            ]
+            mock_load_tokenizer.return_value = tokenizer_huggyllama
+
+            # Load the dataset using the deduplication setting
+            train_dataset, _ = load_prepare_preference_datasets(cfg)
+
+            # Verify that the dataset has been deduplicated
+            assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
+
+    @enable_hf_offline
+    def test_load_without_deduplication(
+        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+    ):
+        # pylint: disable=duplicate-code
+        with (
+            patch(
+                "axolotl.utils.data.shared.load_dataset_w_config"
+            ) as mock_load_dataset,
+            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
+        ):
+            # Set up the mock to return different values on successive calls
+            mock_load_dataset.side_effect = [
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+            ]
+            mock_load_tokenizer.return_value = tokenizer_huggyllama
+
+            cfg.dataset_exact_deduplication = False
+            # Load the dataset without deduplication
+            train_dataset, _ = load_prepare_preference_datasets(cfg)
+
+            # Verify that the dataset retains duplicates
+            assert (
+                len(train_dataset) == 1800 * 2
+            ), "Dataset deduplication occurred when it should not have"

@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
 
     @enable_hf_offline
     def setUp(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
         self.cfg_1 = DictDefault(
             {
                 "base_model": "huggyllama/llama-7b",
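
Both deduplication tests depend on a `unittest.mock` behavior worth spelling out: assigning a list to `side_effect` makes the mock return one element per successive call, which is how the same preference dataset gets served twice here without any network access. A standalone illustration:

```python
from unittest.mock import MagicMock

mock_load = MagicMock()
# A list side_effect yields one element per call, in order, so the
# "dataset" can be handed out twice without touching the hub.
mock_load.side_effect = ["dataset_call_1", "dataset_call_2"]
assert mock_load() == "dataset_call_1"
assert mock_load() == "dataset_call_2"
```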
@@ -1,38 +1,50 @@
 """Module for testing streaming dataset sequence packing"""
 
 import functools
-import unittest
+import random
+import string
 
 import pytest
 import torch
-from datasets import load_dataset
+from datasets import IterableDataset
 from torch.utils.data import DataLoader
-from transformers import AutoTokenizer
 from utils import disable_hf_offline, enable_hf_offline
 
 from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 
 
-class TestPretrainingPacking(unittest.TestCase):
+class TestPretrainingPacking:
     """
     Test class for packing streaming dataset sequences
     """
 
-    @enable_hf_offline
-    def setUp(self) -> None:
-        # pylint: disable=duplicate-code
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.pad_token = "</s>"
+    @pytest.fixture
+    def random_text(self):
+        # seed with random.seed(0) for reproducibility
+        random.seed(0)
+
+        # generate 20 rows of random text with "words" of between 2 and 10 characters and
+        # between 400 to 1200 characters per line
+        data = [
+            "".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
+            for _ in range(20)
+        ] + [
+            " ".join(
+                random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
+            )
+            for _ in range(20)
+        ]
+
+        # Create an IterableDataset
+        def generator():
+            for text in data:
+                yield {"text": text}
+
+        return IterableDataset.from_generator(generator)
 
     @pytest.mark.flaky(retries=1, delay=5)
-    @disable_hf_offline
-    def test_packing_stream_dataset(self):
-        # pylint: disable=duplicate-code
-        dataset = load_dataset(
-            "winglian/tiny-shakespeare",
-            streaming=True,
-        )["train"]
+    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
+        dataset = random_text
 
         cfg = DictDefault(
             {

@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
+        # pylint: disable=duplicate-code
         original_bsz = cfg.micro_batch_size
         train_dataset = wrap_pretraining_dataset(
             dataset,
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             ds_wrapper_partial,
             max_tokens=cfg.sequence_len,

@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
             # [1, original_bsz * cfg.sequence_len]
             # )
             idx += 1
-
-
-if __name__ == "__main__":
-    unittest.main()
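
With the hub download gone, the test now exercises packing against a deterministic synthetic stream. The property under test, greedily concatenating variable-length tokenized rows into fixed-size buffers, can be sketched independently of axolotl's actual implementation (a toy version that does not split rows longer than `max_tokens`):

```python
def pack_sequences(token_streams, max_tokens):
    # Greedily append token lists into a buffer of at most max_tokens,
    # starting a new buffer whenever the next row would overflow it.
    buffers, current = [], []
    for tokens in token_streams:
        if current and len(current) + len(tokens) > max_tokens:
            buffers.append(current)
            current = []
        current.extend(tokens)
    if current:
        buffers.append(current)
    return buffers


packed = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_tokens=5)
assert packed == [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
```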
@@ -19,7 +19,7 @@ def reload_modules(hf_hub_offline):
     importlib.reload(huggingface_hub.constants)
     huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
     importlib.reload(datasets.config)
-    datasets.config.HF_HUB_OFFLINE = hf_hub_offline
+    setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
     reset_sessions()
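
`enable_hf_offline` and `disable_hf_offline`, used as decorators throughout this commit, are not shown in the diff. Given `reload_modules(hf_hub_offline)` above, a plausible sketch is a wrapper that flips the flag for the duration of the call and restores the prior mode afterwards (an assumption, not the verbatim implementation):

```python
import functools

import huggingface_hub.constants


def enable_hf_offline(func):
    # Sketch: force HF_HUB_OFFLINE on around the call via reload_modules()
    # (defined above in the same utils module), restoring the prior mode after.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        previous = huggingface_hub.constants.HF_HUB_OFFLINE
        reload_modules(hf_hub_offline=True)
        try:
            return func(*args, **kwargs)
        finally:
            reload_modules(hf_hub_offline=previous)

    return wrapper
```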