use offline for precached stream dataset (#2453)
This commit is contained in:
@@ -9,9 +9,8 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
||||
from constants import ALPACA_MESSAGES_CONFIG_REVISION
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
from utils import enable_hf_offline
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
||||
|
||||
|
||||
class TestDeduplicateRLDataset(unittest.TestCase):
|
||||
class TestDeduplicateRLDataset:
|
||||
"""Test a configured dataloader with deduplication."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
self.cfg = DictDefault(
|
||||
@pytest.fixture
|
||||
def cfg(self):
|
||||
fixture = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
@@ -235,28 +233,59 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
||||
],
|
||||
}
|
||||
)
|
||||
yield fixture
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
def test_load_with_deduplication(self):
|
||||
def test_load_with_deduplication(
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# Load the dataset using the deduplication setting
|
||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
def test_load_without_deduplication(self):
|
||||
"""Verify that loading without deduplication retains duplicates."""
|
||||
self.cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
len(train_dataset) == 1800 * 2
|
||||
), "Dataset deduplication occurred when it should not have"
|
||||
@enable_hf_offline
|
||||
def test_load_without_deduplication(
|
||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
len(train_dataset) == 1800 * 2
|
||||
), "Dataset deduplication occurred when it should not have"
|
||||
|
||||
|
||||
class TestDeduplicateNonRL(unittest.TestCase):
|
||||
@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
self.cfg_1 = DictDefault(
|
||||
{
|
||||
"base_model": "huggyllama/llama-7b",
|
||||
|
||||
Reference in New Issue
Block a user