Data loader refactor (#2707)
* data loading refactor (wip) * updates * progress * pytest * pytest fix * lint * zero_first -> filelock, more simplifications * small simplification * import change * nit * lint * simplify dedup * couldnt resist * review comments WIP * continued wip * minor changes * fix; remove contrived test * further refactor * set default seed in pydantic config * lint * continued simplication * lint * renaming and nits * filelock tests * fix * fix * lint * remove nullable arg * remove unnecessary code * moving dataset save fn to shared module * remove debug print * matching var naming * fn name change * coderabbit comments * naming nit * fix test
This commit is contained in:
@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
|
||||
`deduplicate_and_log_datasets` during the execution of the preprocess command.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -14,8 +13,7 @@ from datasets import Dataset
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
self.expected_dataset = Dataset.from_dict(self.expected_data)
|
||||
|
||||
def test_deduplication(self):
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=self.dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_datasets_are_none(self):
|
||||
# Test when both datasets are None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
|
||||
def test_only_train_is_none(self):
|
||||
# Test when only train_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=self.dataset
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_only_eval_is_none(self):
|
||||
# Test when only eval_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.dataset, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
|
||||
def test_exact_duplicates(self):
|
||||
# Test when datasets are exact duplicates
|
||||
duplicate_data = {
|
||||
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset, eval_dataset=dataset
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=dataset, other_dataset=dataset
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset_train, eval_dataset=dataset_eval
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=dataset_train, other_dataset=dataset_eval
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
@@ -245,7 +225,9 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
@@ -255,7 +237,8 @@ class TestDeduplicateRLDataset:
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
@@ -269,7 +252,9 @@ class TestDeduplicateRLDataset:
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
@@ -279,9 +264,10 @@ class TestDeduplicateRLDataset:
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
cfg.dataset_exact_deduplication = False
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
@@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, _, _, _ = prepare_dataset(
|
||||
train_dataset, _, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
_, eval_dataset, _, _ = prepare_dataset(
|
||||
_, eval_dataset, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, eval_dataset, _, _ = prepare_dataset(
|
||||
train_dataset, eval_dataset, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase):
|
||||
self.eval_dataset = Dataset.from_dict(self.eval_data)
|
||||
self.dataset = Dataset.from_dict(self.dataset_data)
|
||||
|
||||
@patch(
|
||||
"axolotl.utils.data.utils.sha256",
|
||||
side_effect=lambda x: (
|
||||
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
|
||||
if "sample 5" in x
|
||||
else hashlib.sha256(x.encode("utf-8")).hexdigest()
|
||||
),
|
||||
)
|
||||
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
|
||||
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_train),
|
||||
2,
|
||||
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
2,
|
||||
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
len(self.eval_dataset),
|
||||
"The output eval dataset should have the same number of rows as the input eval dataset.",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(dedup_eval),
|
||||
str(self.eval_dataset),
|
||||
"The string representation of the output eval dataset should be identical to the input eval dataset.",
|
||||
)
|
||||
|
||||
def test_deduplication_dataset_only(self):
|
||||
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
self.assertEqual(
|
||||
len(dedup_dataset), 3, "Dataset should have all original values"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user