use offline for precached stream dataset (#2453)
This commit is contained in:
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -171,6 +171,9 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: huggingface-cli scan-cache
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
|
|||||||
@@ -11,8 +11,10 @@ import time
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
from datasets import load_dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from utils import disable_hf_offline
|
from transformers import AutoTokenizer
|
||||||
|
from utils import disable_hf_offline, enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
@@ -46,7 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_smollm2_135m_model():
|
def download_smollm2_135m_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
|
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
|
||||||
@@ -59,28 +60,24 @@ def download_llama_68m_random_model():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_qwen_2_5_half_billion_model():
|
def download_qwen_2_5_half_billion_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_tatsu_lab_alpaca_dataset():
|
def download_tatsu_lab_alpaca_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
|
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_mhenrichsen_alpaca_2k_dataset():
|
def download_mhenrichsen_alpaca_2k_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
|
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -89,7 +86,6 @@ def download_mhenrichsen_alpaca_2k_w_revision_dataset():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_mlabonne_finetome_100k_dataset():
|
def download_mlabonne_finetome_100k_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
|
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
|
||||||
@@ -124,6 +120,24 @@ def download_fozzie_alpaca_dpo_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
@disable_hf_offline
|
||||||
|
def dataset_fozzie_alpaca_dpo_dataset(
|
||||||
|
download_fozzie_alpaca_dpo_dataset,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
@disable_hf_offline
|
||||||
|
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||||
|
download_fozzie_alpaca_dpo_dataset,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return load_dataset(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
@@ -152,7 +166,6 @@ def download_deepseek_model_fixture():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_huggyllama_model_fixture():
|
def download_huggyllama_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -163,7 +176,6 @@ def download_huggyllama_model_fixture():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_llama_1b_model_fixture():
|
def download_llama_1b_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -174,7 +186,6 @@ def download_llama_1b_model_fixture():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_llama3_8b_model_fixture():
|
def download_llama3_8b_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -183,7 +194,6 @@ def download_llama3_8b_model_fixture():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_llama3_8b_instruct_model_fixture():
|
def download_llama3_8b_instruct_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -194,7 +204,6 @@ def download_llama3_8b_instruct_model_fixture():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@disable_hf_offline
|
|
||||||
def download_phi_35_mini_model_fixture():
|
def download_phi_35_mini_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -263,6 +272,17 @@ def download_llama2_model_fixture():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
@enable_hf_offline
|
||||||
|
def tokenizer_huggyllama(
|
||||||
|
download_huggyllama_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
tokenizer.pad_token = "</s>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_dir():
|
def temp_dir():
|
||||||
# Create a temporary directory
|
# Create a temporary directory
|
||||||
|
|||||||
@@ -109,7 +109,9 @@ def fixture_toolcalling_dataset():
|
|||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def fixture_llama3_tokenizer():
|
def fixture_llama3_tokenizer(
|
||||||
|
download_llama3_8b_instruct_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
@@ -123,7 +125,10 @@ def fixture_smollm2_tokenizer():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
|
||||||
def fixture_mistralv03_tokenizer():
|
@enable_hf_offline
|
||||||
|
def fixture_mistralv03_tokenizer(
|
||||||
|
download_mlx_mistral_7b_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import pytest
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from tokenizers import AddedToken
|
from tokenizers import AddedToken
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
from utils import enable_hf_offline
|
||||||
|
|
||||||
from axolotl.prompt_strategies.chat_template import (
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
ChatTemplatePrompter,
|
ChatTemplatePrompter,
|
||||||
@@ -101,6 +102,7 @@ class TestChatTemplateConfigurations:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
def test_train_on_inputs_true(
|
def test_train_on_inputs_true(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ Test dataset loading under various conditions.
|
|||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from constants import (
|
from constants import (
|
||||||
@@ -15,7 +15,7 @@ from constants import (
|
|||||||
)
|
)
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import AutoTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
from utils import enable_hf_offline
|
from utils import enable_hf_offline
|
||||||
|
|
||||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||||
@@ -23,15 +23,17 @@ from axolotl.utils.data.rl import load_prepare_preference_datasets
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
class TestDatasetPreparation(unittest.TestCase):
|
class TestDatasetPreparation:
|
||||||
"""Test a configured dataloader."""
|
"""Test a configured dataloader."""
|
||||||
|
|
||||||
@enable_hf_offline
|
@pytest.fixture
|
||||||
def setUp(self) -> None:
|
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
yield tokenizer_huggyllama
|
||||||
# Alpaca dataset.
|
|
||||||
self.dataset = Dataset.from_list(
|
@pytest.fixture
|
||||||
|
def dataset_fixture(self):
|
||||||
|
yield Dataset.from_list(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
|
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
|
||||||
@@ -43,7 +45,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_hub(self):
|
def test_load_hub(self, tokenizer):
|
||||||
"""Core use case. Verify that processing data from the hub works"""
|
"""Core use case. Verify that processing data from the hub works"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
@@ -60,9 +62,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -71,7 +71,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||||
def test_load_local_hub(self):
|
def test_load_local_hub(self, tokenizer):
|
||||||
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
|
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
@@ -106,9 +106,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -117,11 +115,11 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_from_save_to_disk(self):
|
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
|
||||||
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
|
||||||
self.dataset.save_to_disk(str(tmp_ds_name))
|
dataset_fixture.save_to_disk(str(tmp_ds_name))
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -137,9 +135,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -147,13 +143,13 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_from_dir_of_parquet(self):
|
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||||
tmp_ds_dir.mkdir()
|
tmp_ds_dir.mkdir()
|
||||||
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
|
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
|
||||||
self.dataset.to_parquet(tmp_ds_path)
|
dataset_fixture.to_parquet(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -174,9 +170,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -184,13 +178,13 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_from_dir_of_json(self):
|
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
|
||||||
"""Standard use case. Verify a directory of json files can be loaded."""
|
"""Standard use case. Verify a directory of json files can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||||
tmp_ds_dir.mkdir()
|
tmp_ds_dir.mkdir()
|
||||||
tmp_ds_path = tmp_ds_dir / "shard1.json"
|
tmp_ds_path = tmp_ds_dir / "shard1.json"
|
||||||
self.dataset.to_json(tmp_ds_path)
|
dataset_fixture.to_json(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -211,9 +205,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -221,11 +213,11 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_from_single_parquet(self):
|
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
|
||||||
"""Standard use case. Verify a single parquet file can be loaded."""
|
"""Standard use case. Verify a single parquet file can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
|
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
|
||||||
self.dataset.to_parquet(tmp_ds_path)
|
dataset_fixture.to_parquet(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -242,9 +234,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -252,11 +242,11 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_from_single_json(self):
|
def test_load_from_single_json(self, tokenizer, dataset_fixture):
|
||||||
"""Standard use case. Verify a single json file can be loaded."""
|
"""Standard use case. Verify a single json file can be loaded."""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
|
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
|
||||||
self.dataset.to_json(tmp_ds_path)
|
dataset_fixture.to_json(tmp_ds_path)
|
||||||
|
|
||||||
prepared_path: Path = Path(tmp_dir) / "prepared"
|
prepared_path: Path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -273,9 +263,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -304,7 +292,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_hub_with_revision(self):
|
def test_load_hub_with_revision(self, tokenizer):
|
||||||
"""Verify that processing data from the hub works with a specific revision"""
|
"""Verify that processing data from the hub works with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
@@ -326,9 +314,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -336,7 +322,9 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_hub_with_revision_with_dpo(self):
|
def test_load_hub_with_revision_with_dpo(
|
||||||
|
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||||
|
):
|
||||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -349,14 +337,23 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
# pylint: disable=duplicate-code
|
||||||
|
with patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset:
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.return_value = (
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
||||||
|
)
|
||||||
|
|
||||||
assert len(train_dataset) == 1800
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
assert "conversation" in train_dataset.features
|
|
||||||
|
assert len(train_dataset) == 1800
|
||||||
|
assert "conversation" in train_dataset.features
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||||
def test_load_local_hub_with_revision(self):
|
def test_load_local_hub_with_revision(self, tokenizer):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
@@ -388,9 +385,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
@@ -399,7 +394,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_loading_local_dataset_folder(self):
|
def test_loading_local_dataset_folder(self, tokenizer):
|
||||||
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
@@ -426,16 +421,10 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||||
self.tokenizer, cfg, prepared_path
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|||||||
@@ -9,9 +9,8 @@ import unittest
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
from constants import ALPACA_MESSAGES_CONFIG_REVISION
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from utils import enable_hf_offline
|
from utils import enable_hf_offline
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
|||||||
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
||||||
|
|
||||||
|
|
||||||
class TestDeduplicateRLDataset(unittest.TestCase):
|
class TestDeduplicateRLDataset:
|
||||||
"""Test a configured dataloader with deduplication."""
|
"""Test a configured dataloader with deduplication."""
|
||||||
|
|
||||||
def setUp(self) -> None:
|
@pytest.fixture
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
def cfg(self):
|
||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
fixture = DictDefault(
|
||||||
self.cfg = DictDefault(
|
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
@@ -235,28 +233,59 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
yield fixture
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_with_deduplication(self):
|
def test_load_with_deduplication(
|
||||||
|
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||||
|
):
|
||||||
"""Verify that loading with deduplication removes duplicates."""
|
"""Verify that loading with deduplication removes duplicates."""
|
||||||
|
|
||||||
# Load the dataset using the deduplication setting
|
# pylint: disable=duplicate-code
|
||||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
with (
|
||||||
|
patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset,
|
||||||
|
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||||
|
):
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.side_effect = [
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
]
|
||||||
|
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||||
|
|
||||||
# Verify that the dataset has been deduplicated
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
|
||||||
|
|
||||||
def test_load_without_deduplication(self):
|
# Verify that the dataset has been deduplicated
|
||||||
"""Verify that loading without deduplication retains duplicates."""
|
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||||
self.cfg.dataset_exact_deduplication = False
|
|
||||||
# Load the dataset without deduplication
|
|
||||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
|
||||||
|
|
||||||
# Verify that the dataset retains duplicates
|
@enable_hf_offline
|
||||||
assert (
|
def test_load_without_deduplication(
|
||||||
len(train_dataset) == 1800 * 2
|
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||||
), "Dataset deduplication occurred when it should not have"
|
):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset,
|
||||||
|
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||||
|
):
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.side_effect = [
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||||
|
]
|
||||||
|
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||||
|
|
||||||
|
cfg.dataset_exact_deduplication = False
|
||||||
|
# Load the dataset without deduplication
|
||||||
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
|
|
||||||
|
# Verify that the dataset retains duplicates
|
||||||
|
assert (
|
||||||
|
len(train_dataset) == 1800 * 2
|
||||||
|
), "Dataset deduplication occurred when it should not have"
|
||||||
|
|
||||||
|
|
||||||
class TestDeduplicateNonRL(unittest.TestCase):
|
class TestDeduplicateNonRL(unittest.TestCase):
|
||||||
@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
|
||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
|
||||||
self.cfg_1 = DictDefault(
|
self.cfg_1 = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "huggyllama/llama-7b",
|
"base_model": "huggyllama/llama-7b",
|
||||||
|
|||||||
@@ -1,38 +1,50 @@
|
|||||||
"""Module for testing streaming dataset sequence packing"""
|
"""Module for testing streaming dataset sequence packing"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import unittest
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import IterableDataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from utils import disable_hf_offline, enable_hf_offline
|
|
||||||
|
|
||||||
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
|
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
class TestPretrainingPacking(unittest.TestCase):
|
class TestPretrainingPacking:
|
||||||
"""
|
"""
|
||||||
Test class for packing streaming dataset sequences
|
Test class for packing streaming dataset sequences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@enable_hf_offline
|
@pytest.fixture
|
||||||
def setUp(self) -> None:
|
def random_text(self):
|
||||||
# pylint: disable=duplicate-code
|
# seed with random.seed(0) for reproducibility
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
random.seed(0)
|
||||||
self.tokenizer.pad_token = "</s>"
|
|
||||||
|
# generate 20 rows of random text with "words" of between 2 and 10 characters and
|
||||||
|
# between 400 to 1200 characters per line
|
||||||
|
data = [
|
||||||
|
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
|
||||||
|
for _ in range(20)
|
||||||
|
] + [
|
||||||
|
" ".join(
|
||||||
|
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
|
||||||
|
)
|
||||||
|
for _ in range(20)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create an IterableDataset
|
||||||
|
def generator():
|
||||||
|
for text in data:
|
||||||
|
yield {"text": text}
|
||||||
|
|
||||||
|
return IterableDataset.from_generator(generator)
|
||||||
|
|
||||||
@pytest.mark.flaky(retries=1, delay=5)
|
@pytest.mark.flaky(retries=1, delay=5)
|
||||||
@disable_hf_offline
|
def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
|
||||||
def test_packing_stream_dataset(self):
|
dataset = random_text
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
dataset = load_dataset(
|
|
||||||
"winglian/tiny-shakespeare",
|
|
||||||
streaming=True,
|
|
||||||
)["train"]
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
ds_wrapper_partial = functools.partial(
|
ds_wrapper_partial = functools.partial(
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
cfg.pretraining_dataset[0],
|
cfg.pretraining_dataset[0],
|
||||||
self.tokenizer,
|
tokenizer_huggyllama,
|
||||||
cfg,
|
cfg,
|
||||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
original_bsz = cfg.micro_batch_size
|
original_bsz = cfg.micro_batch_size
|
||||||
train_dataset = wrap_pretraining_dataset(
|
train_dataset = wrap_pretraining_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
self.tokenizer,
|
tokenizer_huggyllama,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_partial,
|
ds_wrapper_partial,
|
||||||
max_tokens=cfg.sequence_len,
|
max_tokens=cfg.sequence_len,
|
||||||
@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
# [1, original_bsz * cfg.sequence_len]
|
# [1, original_bsz * cfg.sequence_len]
|
||||||
# )
|
# )
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def reload_modules(hf_hub_offline):
|
|||||||
importlib.reload(huggingface_hub.constants)
|
importlib.reload(huggingface_hub.constants)
|
||||||
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
|
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
|
||||||
importlib.reload(datasets.config)
|
importlib.reload(datasets.config)
|
||||||
datasets.config.HF_HUB_OFFLINE = hf_hub_offline
|
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
|
||||||
reset_sessions()
|
reset_sessions()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user