use offline for precached stream dataset (#2453)

This commit is contained in:
Wing Lian
2025-03-28 23:39:09 -04:00
committed by GitHub
parent e46239f8d3
commit c49682132b
8 changed files with 179 additions and 124 deletions

View File

@@ -9,9 +9,8 @@ import unittest
from unittest.mock import patch
import pytest
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from constants import ALPACA_MESSAGES_CONFIG_REVISION
from datasets import Dataset
from transformers import AutoTokenizer
from utils import enable_hf_offline
from axolotl.utils.config import normalize_config
@@ -216,13 +215,12 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset(unittest.TestCase):
class TestDeduplicateRLDataset:
"""Test a configured dataloader with deduplication."""
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg = DictDefault(
@pytest.fixture
def cfg(self):
fixture = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
@@ -235,28 +233,59 @@ class TestDeduplicateRLDataset(unittest.TestCase):
],
}
)
yield fixture
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_with_deduplication(self):
def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
"""Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
train_dataset, _ = load_prepare_preference_datasets(cfg)
def test_load_without_deduplication(self):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
@enable_hf_offline
def test_load_without_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase):
@@ -264,8 +293,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
@enable_hf_offline
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault(
{
"base_model": "huggyllama/llama-7b",