use offline for precached stream dataset (#2453)

This commit is contained in:
Wing Lian
2025-03-28 23:39:09 -04:00
committed by GitHub
parent e46239f8d3
commit c49682132b
8 changed files with 179 additions and 124 deletions

View File

@@ -171,6 +171,9 @@ jobs:
run: |
axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/

View File

@@ -11,8 +11,10 @@ import time
import pytest
import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download
from utils import disable_hf_offline
from transformers import AutoTokenizer
from utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -46,7 +48,6 @@ def snapshot_download_w_retry(*args, **kwargs):
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_smollm2_135m_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@@ -59,28 +60,24 @@ def download_llama_68m_random_model():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_tatsu_lab_alpaca_dataset():
# download the dataset
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset
snapshot_download_w_retry(
@@ -89,7 +86,6 @@ def download_mhenrichsen_alpaca_2k_w_revision_dataset():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_mlabonne_finetome_100k_dataset():
# download the dataset
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")
@@ -124,6 +120,24 @@ def download_fozzie_alpaca_dpo_dataset():
)
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
@pytest.fixture(scope="session")
@disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
)
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
@@ -152,7 +166,6 @@ def download_deepseek_model_fixture():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_huggyllama_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -163,7 +176,6 @@ def download_huggyllama_model_fixture():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_llama_1b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -174,7 +186,6 @@ def download_llama_1b_model_fixture():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -183,7 +194,6 @@ def download_llama3_8b_model_fixture():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -194,7 +204,6 @@ def download_llama3_8b_instruct_model_fixture():
@pytest.fixture(scope="session", autouse=True)
@disable_hf_offline
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -263,6 +272,17 @@ def download_llama2_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture
def temp_dir():
# Create a temporary directory

View File

@@ -109,7 +109,9 @@ def fixture_toolcalling_dataset():
@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
@enable_hf_offline
def fixture_llama3_tokenizer():
def fixture_llama3_tokenizer(
download_llama3_8b_instruct_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer
@@ -123,7 +125,10 @@ def fixture_smollm2_tokenizer():
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer():
@enable_hf_offline
def fixture_mistralv03_tokenizer(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
)

View File

@@ -9,6 +9,7 @@ import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import PreTrainedTokenizer
from utils import enable_hf_offline
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
@@ -101,6 +102,7 @@ class TestChatTemplateConfigurations:
return True
return False
@enable_hf_offline
def test_train_on_inputs_true(
self,
tokenizer,

View File

@@ -4,8 +4,8 @@ Test dataset loading under various conditions.
import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
from constants import (
@@ -15,7 +15,7 @@ from constants import (
)
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from utils import enable_hf_offline
from axolotl.utils.data import load_tokenized_prepared_datasets
@@ -23,15 +23,17 @@ from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
class TestDatasetPreparation(unittest.TestCase):
class TestDatasetPreparation:
"""Test a configured dataloader."""
@enable_hf_offline
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
# Alpaca dataset.
self.dataset = Dataset.from_list(
@pytest.fixture
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
yield tokenizer_huggyllama
@pytest.fixture
def dataset_fixture(self):
yield Dataset.from_list(
[
{
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -43,7 +45,7 @@ class TestDatasetPreparation(unittest.TestCase):
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self):
def test_load_hub(self, tokenizer):
"""Core use case. Verify that processing data from the hub works"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
@@ -60,9 +62,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -71,7 +71,7 @@ class TestDatasetPreparation(unittest.TestCase):
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub(self):
def test_load_local_hub(self, tokenizer):
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -106,9 +106,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -117,11 +115,11 @@ class TestDatasetPreparation(unittest.TestCase):
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_load_from_save_to_disk(self):
def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
"""Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
self.dataset.save_to_disk(str(tmp_ds_name))
dataset_fixture.save_to_disk(str(tmp_ds_name))
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -137,9 +135,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -147,13 +143,13 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_dir_of_parquet(self):
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.parquet"
self.dataset.to_parquet(tmp_ds_path)
dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -174,9 +170,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -184,13 +178,13 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_dir_of_json(self):
def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a directory of json files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir()
tmp_ds_path = tmp_ds_dir / "shard1.json"
self.dataset.to_json(tmp_ds_path)
dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -211,9 +205,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -221,11 +213,11 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_single_parquet(self):
def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single parquet file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
self.dataset.to_parquet(tmp_ds_path)
dataset_fixture.to_parquet(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -242,9 +234,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -252,11 +242,11 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features
@enable_hf_offline
def test_load_from_single_json(self):
def test_load_from_single_json(self, tokenizer, dataset_fixture):
"""Standard use case. Verify a single json file can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
self.dataset.to_json(tmp_ds_path)
dataset_fixture.to_json(tmp_ds_path)
prepared_path: Path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
@@ -273,9 +263,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -304,7 +292,7 @@ class TestDatasetPreparation(unittest.TestCase):
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub_with_revision(self):
def test_load_hub_with_revision(self, tokenizer):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
@@ -326,9 +314,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -336,7 +322,9 @@ class TestDatasetPreparation(unittest.TestCase):
assert "labels" in dataset.features
@enable_hf_offline
def test_load_hub_with_revision_with_dpo(self):
def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault(
@@ -349,14 +337,23 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
# pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self):
def test_load_local_hub_with_revision(self, tokenizer):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -388,9 +385,7 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -399,7 +394,7 @@ class TestDatasetPreparation(unittest.TestCase):
shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_loading_local_dataset_folder(self):
def test_loading_local_dataset_folder(self, tokenizer):
"""Verify that a dataset downloaded to a local folder can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -426,16 +421,10 @@ class TestDatasetPreparation(unittest.TestCase):
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,9 +9,8 @@ import unittest
from unittest.mock import patch
import pytest
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from constants import ALPACA_MESSAGES_CONFIG_REVISION
from datasets import Dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.utils.config import normalize_config
@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset(unittest.TestCase):
class TestDeduplicateRLDataset:
"""Test a configured dataloader with deduplication."""
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg = DictDefault(
@pytest.fixture
def cfg(self):
fixture = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
@@ -235,28 +233,59 @@ class TestDeduplicateRLDataset(unittest.TestCase):
],
}
)
yield fixture
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_with_deduplication(self):
def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
"""Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
train_dataset, _ = load_prepare_preference_datasets(cfg)
def test_load_without_deduplication(self):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
@enable_hf_offline
def test_load_without_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase):
@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
@enable_hf_offline
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault(
{
"base_model": "huggyllama/llama-7b",

View File

@@ -1,38 +1,50 @@
"""Module for testing streaming dataset sequence packing"""
import functools
import unittest
import random
import string
import pytest
import torch
from datasets import load_dataset
from datasets import IterableDataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utils import disable_hf_offline, enable_hf_offline
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault
class TestPretrainingPacking(unittest.TestCase):
class TestPretrainingPacking:
"""
Test class for packing streaming dataset sequences
"""
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
@pytest.fixture
def random_text(self):
# seed with random.seed(0) for reproducibility
random.seed(0)
# generate 20 rows of random text with "words" of between 2 and 10 characters and
# between 400 to 1200 characters per line
data = [
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
for _ in range(20)
] + [
" ".join(
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
)
for _ in range(20)
]
# Create an IterableDataset
def generator():
for text in data:
yield {"text": text}
return IterableDataset.from_generator(generator)
@pytest.mark.flaky(retries=1, delay=5)
@disable_hf_offline
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
"winglian/tiny-shakespeare",
streaming=True,
)["train"]
def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
dataset = random_text
cfg = DictDefault(
{
@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
self.tokenizer,
tokenizer_huggyllama,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
# pylint: disable=duplicate-code
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
dataset,
self.tokenizer,
tokenizer_huggyllama,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
# [1, original_bsz * cfg.sequence_len]
# )
idx += 1
if __name__ == "__main__":
unittest.main()

View File

@@ -19,7 +19,7 @@ def reload_modules(hf_hub_offline):
importlib.reload(huggingface_hub.constants)
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
importlib.reload(datasets.config)
datasets.config.HF_HUB_OFFLINE = hf_hub_offline
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
reset_sessions()