use offline for precached stream dataset (#2453)

This commit is contained in:
Wing Lian
2025-03-28 23:39:09 -04:00
committed by GitHub
parent e46239f8d3
commit c49682132b
8 changed files with 179 additions and 124 deletions

View File

@@ -171,6 +171,9 @@ jobs:
run: | run: |
axolotl --help axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/

View File

@@ -11,8 +11,10 @@ import time
import pytest import pytest
import requests import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from utils import disable_hf_offline from transformers import AutoTokenizer
from utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1): def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -46,7 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_smollm2_135m_model(): def download_smollm2_135m_model():
# download the model # download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model") snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@@ -59,28 +60,24 @@ def download_llama_68m_random_model():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_qwen_2_5_half_billion_model(): def download_qwen_2_5_half_billion_model():
# download the model # download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model") snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_tatsu_lab_alpaca_dataset(): def download_tatsu_lab_alpaca_dataset():
# download the dataset # download the dataset
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset") snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_mhenrichsen_alpaca_2k_dataset(): def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset # download the dataset
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset") snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_mhenrichsen_alpaca_2k_w_revision_dataset(): def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset # download the dataset
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -89,7 +86,6 @@ def download_mhenrichsen_alpaca_2k_w_revision_dataset():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_mlabonne_finetome_100k_dataset(): def download_mlabonne_finetome_100k_dataset():
# download the dataset # download the dataset
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset") snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
@@ -124,6 +120,24 @@ def download_fozzie_alpaca_dpo_dataset():
) )
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset(): def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset # download the dataset
@@ -152,7 +166,6 @@ def download_deepseek_model_fixture():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_huggyllama_model_fixture(): def download_huggyllama_model_fixture():
# download the tokenizer only # download the tokenizer only
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -163,7 +176,6 @@ def download_huggyllama_model_fixture():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_llama_1b_model_fixture(): def download_llama_1b_model_fixture():
# download the tokenizer only # download the tokenizer only
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -174,7 +186,6 @@ def download_llama_1b_model_fixture():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_llama3_8b_model_fixture(): def download_llama3_8b_model_fixture():
# download the tokenizer only # download the tokenizer only
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -183,7 +194,6 @@ def download_llama3_8b_model_fixture():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_llama3_8b_instruct_model_fixture(): def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only # download the tokenizer only
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -194,7 +204,6 @@ def download_llama3_8b_instruct_model_fixture():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_phi_35_mini_model_fixture(): def download_phi_35_mini_model_fixture():
# download the tokenizer only # download the tokenizer only
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -263,6 +272,17 @@ def download_llama2_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
# Create a temporary directory # Create a temporary directory

View File

@@ -109,7 +109,9 @@ def fixture_toolcalling_dataset():
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True) @pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
@enable_hf_offline @enable_hf_offline
def fixture_llama3_tokenizer(): def fixture_llama3_tokenizer(
download_llama3_8b_instruct_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer return tokenizer
@@ -123,7 +125,10 @@ def fixture_smollm2_tokenizer():
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True) @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer(): @enable_hf_offline
def fixture_mistralv03_tokenizer(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit" "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
) )

View File

@@ -9,6 +9,7 @@ import pytest
from datasets import Dataset from datasets import Dataset
from tokenizers import AddedToken from tokenizers import AddedToken
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.chat_template import ( from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter, ChatTemplatePrompter,
@@ -101,6 +102,7 @@ class TestChatTemplateConfigurations:
return True return True
return False return False
@enable_hf_offline
def test_train_on_inputs_true( def test_train_on_inputs_true(
self, self,
tokenizer, tokenizer,

View File

@@ -4,8 +4,8 @@ Test dataset loading under various conditions.
import shutil import shutil
import tempfile import tempfile
import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import patch
import pytest import pytest
from constants import ( from constants import (
@@ -15,7 +15,7 @@ from constants import (
) )
from datasets import Dataset from datasets import Dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import PreTrainedTokenizer
from utils import enable_hf_offline from utils import enable_hf_offline
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
@@ -23,15 +23,17 @@ from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
class TestDatasetPreparation(unittest.TestCase): class TestDatasetPreparation:
"""Test a configured dataloader.""" """Test a configured dataloader."""
@enable_hf_offline @pytest.fixture
def setUp(self) -> None: def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) yield tokenizer_huggyllama
# Alpaca dataset.
self.dataset = Dataset.from_list( @pytest.fixture
def dataset_fixture(self):
yield Dataset.from_list(
[ [
{ {
"instruction": "Evaluate this sentence for spelling and grammar mistakes", "instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -43,7 +45,7 @@ class TestDatasetPreparation(unittest.TestCase):
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits") @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline @enable_hf_offline
def test_load_hub(self): def test_load_hub(self, tokenizer):
"""Core use case. Verify that processing data from the hub works""" """Core use case. Verify that processing data from the hub works"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
@@ -60,9 +62,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -71,7 +71,7 @@ class TestDatasetPreparation(unittest.TestCase):
@enable_hf_offline @enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline") @pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub(self): def test_load_local_hub(self, tokenizer):
"""Niche use case. Verify that a local copy of a hub dataset can be loaded""" """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -106,9 +106,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -117,11 +115,11 @@ class TestDatasetPreparation(unittest.TestCase):
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
@enable_hf_offline @enable_hf_offline
def test_load_from_save_to_disk(self): def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded.""" """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset" tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name)) dataset_fixture.save_to_disk(str(tmp_ds_name))
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -137,9 +135,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -147,13 +143,13 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
@enable_hf_offline @enable_hf_offline
def test_load_from_dir_of_parquet(self): def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded.""" """Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.parquet" tmp_ds_path = tmp_ds_dir / "shard1.parquet"
self.dataset.to_parquet(tmp_ds_path) dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -174,9 +170,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -184,13 +178,13 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
@enable_hf_offline @enable_hf_offline
def test_load_from_dir_of_json(self): def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a directory of json files can be loaded.""" """Standard use case. Verify a directory of json files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir() tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.json" tmp_ds_path = tmp_ds_dir / "shard1.json"
self.dataset.to_json(tmp_ds_path) dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -211,9 +205,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -221,11 +213,11 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
@enable_hf_offline @enable_hf_offline
def test_load_from_single_parquet(self): def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single parquet file can be loaded.""" """Standard use case. Verify a single parquet file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet" tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
self.dataset.to_parquet(tmp_ds_path) dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -242,9 +234,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -252,11 +242,11 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
@enable_hf_offline @enable_hf_offline
def test_load_from_single_json(self): def test_load_from_single_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single json file can be loaded.""" """Standard use case. Verify a single json file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json" tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
self.dataset.to_json(tmp_ds_path) dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared" prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault( cfg = DictDefault(
@@ -273,9 +263,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 1 assert len(dataset) == 1
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -304,7 +292,7 @@ class TestDatasetPreparation(unittest.TestCase):
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits") @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline @enable_hf_offline
def test_load_hub_with_revision(self): def test_load_hub_with_revision(self, tokenizer):
"""Verify that processing data from the hub works with a specific revision""" """Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared" prepared_path = Path(tmp_dir) / "prepared"
@@ -326,9 +314,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -336,7 +322,9 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features assert "labels" in dataset.features
@enable_hf_offline @enable_hf_offline
def test_load_hub_with_revision_with_dpo(self): def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision""" """Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault( cfg = DictDefault(
@@ -349,14 +337,23 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_preference_datasets(cfg) # pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return the pre-loaded dataset on every call
mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
)
assert len(train_dataset) == 1800 train_dataset, _ = load_prepare_preference_datasets(cfg)
assert "conversation" in train_dataset.features
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@enable_hf_offline @enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline") @pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self): def test_load_local_hub_with_revision(self, tokenizer):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision""" """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -388,9 +385,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
@@ -399,7 +394,7 @@ class TestDatasetPreparation(unittest.TestCase):
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
@enable_hf_offline @enable_hf_offline
def test_loading_local_dataset_folder(self): def test_loading_local_dataset_folder(self, tokenizer):
"""Verify that a dataset downloaded to a local folder can be loaded""" """Verify that a dataset downloaded to a local folder can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
@@ -426,16 +421,10 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
dataset, _ = load_tokenized_prepared_datasets( dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000 assert len(dataset) == 2000
assert "input_ids" in dataset.features assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,9 +9,8 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS from constants import ALPACA_MESSAGES_CONFIG_REVISION
from datasets import Dataset from datasets import Dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline from utils import enable_hf_offline
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset") verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset(unittest.TestCase): class TestDeduplicateRLDataset:
"""Test a configured dataloader with deduplication.""" """Test a configured dataloader with deduplication."""
def setUp(self) -> None: @pytest.fixture
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") def cfg(self):
self.tokenizer.add_special_tokens(SPECIAL_TOKENS) fixture = DictDefault(
self.cfg = DictDefault(
{ {
"tokenizer_config": "huggyllama/llama-7b", "tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024, "sequence_len": 1024,
@@ -235,28 +233,59 @@ class TestDeduplicateRLDataset(unittest.TestCase):
], ],
} }
) )
yield fixture
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline @enable_hf_offline
def test_load_with_deduplication(self): def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting # pylint: disable=duplicate-code
train_dataset, _ = load_prepare_preference_datasets(self.cfg) with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return the same pre-loaded dataset on each of two successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
# Verify that the dataset has been deduplicated train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
def test_load_without_deduplication(self): # Verify that the dataset has been deduplicated
"""Verify that loading without deduplication retains duplicates.""" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates @enable_hf_offline
assert ( def test_load_without_deduplication(
len(train_dataset) == 1800 * 2 self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
), "Dataset deduplication occurred when it should not have" ):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return the same pre-loaded dataset on each of two successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase): class TestDeduplicateNonRL(unittest.TestCase):
@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
@enable_hf_offline @enable_hf_offline
def setUp(self) -> None: def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault( self.cfg_1 = DictDefault(
{ {
"base_model": "huggyllama/llama-7b", "base_model": "huggyllama/llama-7b",

View File

@@ -1,38 +1,50 @@
"""Module for testing streaming dataset sequence packing""" """Module for testing streaming dataset sequence packing"""
import functools import functools
import unittest import random
import string
import pytest import pytest
import torch import torch
from datasets import load_dataset from datasets import IterableDataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utils import disable_hf_offline, enable_hf_offline
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
class TestPretrainingPacking(unittest.TestCase): class TestPretrainingPacking:
""" """
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
""" """
@enable_hf_offline @pytest.fixture
def setUp(self) -> None: def random_text(self):
# pylint: disable=duplicate-code # seed with random.seed(0) for reproducibility
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") random.seed(0)
self.tokenizer.pad_token = "</s>"
# generate 20 short rows of 2 to 10 random lowercase letters, followed by 20 long rows of
# 400 to 1200 random lowercase letters separated by single spaces
data = [
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
for _ in range(20)
] + [
" ".join(
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
)
for _ in range(20)
]
# Create an IterableDataset
def generator():
for text in data:
yield {"text": text}
return IterableDataset.from_generator(generator)
@pytest.mark.flaky(retries=1, delay=5) @pytest.mark.flaky(retries=1, delay=5)
@disable_hf_offline def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
def test_packing_stream_dataset(self): dataset = random_text
# pylint: disable=duplicate-code
dataset = load_dataset(
"winglian/tiny-shakespeare",
streaming=True,
)["train"]
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
self.tokenizer, tokenizer_huggyllama,
cfg, cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
# pylint: disable=duplicate-code
original_bsz = cfg.micro_batch_size original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
dataset, dataset,
self.tokenizer, tokenizer_huggyllama,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
# [1, original_bsz * cfg.sequence_len] # [1, original_bsz * cfg.sequence_len]
# ) # )
idx += 1 idx += 1
if __name__ == "__main__":
unittest.main()

View File

@@ -19,7 +19,7 @@ def reload_modules(hf_hub_offline):
importlib.reload(huggingface_hub.constants) importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config) importlib.reload(datasets.config)
datasets.config.HF_HUB_OFFLINE = hf_hub_offline setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions() reset_sessions()