Data loader refactor (#2707)

* data loading refactor (wip)

* updates

* progress

* pytest

* pytest fix

* lint

* zero_first -> filelock, more simplifications

* small simplification

* import change

* nit

* lint

* simplify dedup

* couldn't resist

* review comments WIP

* continued wip

* minor changes

* fix; remove contrived test

* further refactor

* set default seed in pydantic config

* lint

* continued simplification

* lint

* renaming and nits

* filelock tests

* fix

* fix

* lint

* remove nullable arg

* remove unnecessary code

* moving dataset save fn to shared module

* remove debug print

* matching var naming

* fn name change

* coderabbit comments

* naming nit

* fix test
This commit is contained in:
Dan Saunders
2025-06-10 19:53:07 -04:00
committed by GitHub
parent 52a0452acb
commit 00cda8cc70
62 changed files with 2125 additions and 1436 deletions

View File

@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""
import hashlib
import unittest
from unittest.mock import patch
@@ -14,8 +13,7 @@ from datasets import Dataset
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
self.expected_dataset = Dataset.from_dict(self.expected_data)
def test_deduplication(self):
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=self.dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_datasets_are_none(self):
# Test when both datasets are None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=None
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
def test_only_train_is_none(self):
# Test when only train_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=self.dataset
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_only_eval_is_none(self):
# Test when only eval_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=self.dataset, eval_dataset=None
)
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
def test_exact_duplicates(self):
# Test when datasets are exact duplicates
duplicate_data = {
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset, eval_dataset=dataset
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset, other_dataset=dataset
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset_train, eval_dataset=dataset_eval
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset_train, other_dataset=dataset_eval
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -245,7 +225,9 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -255,7 +237,8 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -269,7 +252,9 @@ class TestDeduplicateRLDataset:
):
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -279,9 +264,10 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
cfg.dataset_exact_deduplication = False
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset retains duplicates
assert (
@@ -335,7 +321,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare datasets using the prepare_datasets function
train_dataset, _, _, _ = prepare_dataset(
train_dataset, _, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -362,7 +348,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare datasets using the prepare_datasets function
_, eval_dataset, _, _ = prepare_dataset(
_, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -389,7 +375,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare datasets using the prepare_datasets function
train_dataset, eval_dataset, _, _ = prepare_dataset(
train_dataset, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -428,41 +414,8 @@ class TestWrongCollisions(unittest.TestCase):
self.eval_dataset = Dataset.from_dict(self.eval_data)
self.dataset = Dataset.from_dict(self.dataset_data)
@patch(
"axolotl.utils.data.utils.sha256",
side_effect=lambda x: (
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
if "sample 5" in x
else hashlib.sha256(x.encode("utf-8")).hexdigest()
),
)
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
self.assertEqual(
len(dedup_train),
2,
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
2,
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
len(self.eval_dataset),
"The output eval dataset should have the same number of rows as the input eval dataset.",
)
self.assertEqual(
str(dedup_eval),
str(self.eval_dataset),
"The string representation of the output eval dataset should be identical to the input eval dataset.",
)
def test_deduplication_dataset_only(self):
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
self.assertEqual(
len(dedup_dataset), 3, "Dataset should have all original values"
)