diff --git a/docs/config.qmd b/docs/config.qmd
index 04e278e2d..bc3730095 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -162,6 +162,9 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
+# Whether to drop exact-duplicate rows from `datasets` and `test_datasets`. Default is false.
+dataset_exact_deduplication: true
+
# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml
new file mode 100644
index 000000000..35a0260ca
--- /dev/null
+++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -0,0 +1,95 @@
+base_model: meta-llama/Llama-3.2-1B
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+rl: dpo
+datasets:
+ - path: fozziethebeat/alpaca_messages_2k_dpo_test
+ type: chat_template.default
+ field_messages: conversation
+ field_chosen: chosen
+ field_rejected: rejected
+ message_field_role: role
+ message_field_content: content
+ roles:
+ system:
+ - system
+ user:
+ - user
+ assistant:
+ - assistant
+ - path: fozziethebeat/alpaca_messages_2k_dpo_test
+ type: chat_template.default
+ field_messages: conversation
+ field_chosen: chosen
+ field_rejected: rejected
+ message_field_role: role
+ message_field_content: content
+ roles:
+ system:
+ - system
+ user:
+ - user
+ assistant:
+ - assistant
+
+dataset_exact_deduplication: true
+dataset_prepared_path:
+val_set_size: 0
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml
new file mode 100644
index 000000000..c07d5f8ff
--- /dev/null
+++ b/examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -0,0 +1,76 @@
+base_model: meta-llama/Llama-3.2-1B
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./outputs/lora-out
+
+dataset_exact_deduplication: true
+# NOTE: the duplicate dataset entries above are intentional; they exercise deduplication
+
+sequence_len: 4096
+sample_packing: true
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_modules_to_save:
+ - embed_tokens
+ - lm_head
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: <|end_of_text|>
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 7c8db7ce8..86cc30a40 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -139,7 +139,7 @@ def check_remote_config(config: Union[str, Path]):
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
- f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
+ f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index eeadfd067..1ac7efbfa 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -624,6 +624,7 @@ class AxolotlInputConfig(
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
+ dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
dataloader_num_workers: Optional[int] = None
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index cf1226175..edb72f186 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
-from axolotl.utils.data.utils import md5
+from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
@@ -208,4 +208,9 @@ def load_prepare_dpo_datasets(cfg):
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
+ if cfg.dataset_exact_deduplication:
+ train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+ train_dataset=train_dataset, eval_dataset=eval_dataset
+ )
+
return train_dataset, eval_dataset
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index b72e0a2b3..0bee4dd5c 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -44,7 +44,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
-from axolotl.utils.data.utils import md5
+from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -136,8 +136,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
+ if cfg.dataset_exact_deduplication:
+ LOG.info("Deduplication not available for pretrained datasets")
return train_dataset, eval_dataset, cfg.max_steps, prompters
-
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
@@ -178,7 +179,7 @@ def load_tokenized_prepared_datasets(
+ "|".join(
sorted(
[
- f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
+ f"{d.path}: {d.type}: {d.shards}: {d.conversation}{d.split}"
for d in cfg_datasets
]
)
@@ -584,7 +585,8 @@ def load_prepare_datasets(
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
-
+ if cfg.dataset_exact_deduplication:
+ _, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
@@ -596,12 +598,17 @@ def load_prepare_datasets(
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
+ if cfg.dataset_exact_deduplication:
+ _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+ else:
+ eval_dataset = dataset
train_dataset = None
- eval_dataset = dataset
else:
- train_dataset = dataset
+ if cfg.dataset_exact_deduplication:
+ train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
+ else:
+ train_dataset = dataset
eval_dataset = None
-
return train_dataset, eval_dataset, prompters
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index e05701e7b..56bcddd8e 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -1,6 +1,11 @@
"""data handling helpers"""
import hashlib
+import logging
+
+from datasets import Dataset
+
+LOG = logging.getLogger("axolotl")
def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -8,3 +13,96 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
+
+
+def sha256(to_hash: str, encoding: str = "utf-8") -> str:
+ return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
+
+
+def deduplicate_dataset(
+ dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
+) -> Dataset:
+ unique_indices = []
+
+ for idx, row in enumerate(dataset):
+ row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
+ if row_hash not in seen_hashes:
+ seen_hashes[row_hash] = [idx]
+ unique_indices.append(idx)
+ else:
+ # Check for collision by looking up the original dataset indices
+ original_indices = seen_hashes[row_hash]
+ is_duplicate = False
+ for original_idx in original_indices:
+ if (
+ not idx == original_idx
+ and original_idx < len(dataset)
+ and str(dataset[original_idx]) == str(row)
+ ):
+ is_duplicate = True
+ break
+ # Check in the other dataset if provided
+ if other_dataset is not None:
+ if original_idx < len(other_dataset) and str(
+ other_dataset[original_idx]
+ ) == str(row):
+ is_duplicate = True
+ break
+ if not is_duplicate:
+ seen_hashes[row_hash].append(idx)
+ unique_indices.append(idx)
+ continue
+ return dataset.select(unique_indices)
+
+
+def deduplicate_and_log_datasets(
+ *,
+ train_dataset: Dataset = None,
+ eval_dataset: Dataset = None,
+ dataset: Dataset = None,
+) -> tuple[Dataset, Dataset, Dataset]:
+ """
+ Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
+
+ Returns:
+ tuple: Deduplicated train, eval, and additional datasets.
+ """
+ seen_hashes: dict[str, list[int]] = {}
+
+ # Handle cases where datasets are None
+ if train_dataset is not None:
+ LOG.info(
+ f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
+ )
+ train_dataset = deduplicate_dataset(
+ dataset=train_dataset, seen_hashes=seen_hashes
+ )
+ LOG.info(
+ f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
+ )
+ else:
+ LOG.info("Train dataset is None. Skipping deduplication.")
+
+ if eval_dataset is not None:
+ LOG.info(
+ f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
+ )
+ eval_dataset = deduplicate_dataset(
+ dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
+ )
+ LOG.info(
+ f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
+ )
+ else:
+ LOG.info("Eval dataset is None. Skipping deduplication.")
+
+ if dataset is not None and (eval_dataset is None and train_dataset is None):
+ LOG.info(
+ f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
+ )
+ dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
+ LOG.info(
+ f"Deduplication complete for combined dataset. New size: {len(dataset)}"
+ )
+
+ return train_dataset, eval_dataset, dataset
diff --git a/tests/constants.py b/tests/constants.py
new file mode 100644
index 000000000..e024e6920
--- /dev/null
+++ b/tests/constants.py
@@ -0,0 +1,32 @@
+# constants.py
+"""
+This module contains constants and configuration dictionaries used for
+datasets and other utilities in the Axolotl project, specifically for testing.
+"""
+# Configuration for Alpaca Messages Dataset
+ALPACA_MESSAGES_CONFIG_OG = {
+ "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+ "type": "chat_template.default",
+ "chat_template": "llama3",
+ "field_messages": "conversation",
+ "field_chosen": "chosen",
+ "field_rejected": "rejected",
+ "message_field_role": "role",
+ "message_field_content": "content",
+ "roles": {
+ "system": ["system"],
+ "user": ["user"],
+ "assistant": ["assistant"],
+ },
+}
+
+# Revision configuration extending the original
+ALPACA_MESSAGES_CONFIG_REVISION = ALPACA_MESSAGES_CONFIG_OG.copy()
+ALPACA_MESSAGES_CONFIG_REVISION["revision"] = "ea82cff"
+
+
+SPECIAL_TOKENS = {
+ "bos_token": "",
+ "eos_token": "",
+ "unk_token": "",
+}
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e87f19cc7..f3bed00fd 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -7,6 +7,11 @@ import tempfile
import unittest
from pathlib import Path
+from constants import (
+ ALPACA_MESSAGES_CONFIG_OG,
+ ALPACA_MESSAGES_CONFIG_REVISION,
+ SPECIAL_TOKENS,
+)
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
@@ -21,13 +26,7 @@ class TestDatasetPreparation(unittest.TestCase):
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
- self.tokenizer.add_special_tokens(
- {
- "bos_token": "",
- "eos_token": "",
- "unk_token": "",
- }
- )
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
# Alpaca dataset.
self.dataset = Dataset.from_list(
[
@@ -277,23 +276,7 @@ class TestDatasetPreparation(unittest.TestCase):
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
- "datasets": [
- {
- "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
- "type": "chat_template.default",
- "chat_template": "llama3",
- "field_messages": "conversation",
- "field_chosen": "chosen",
- "field_rejected": "rejected",
- "message_field_role": "role",
- "message_field_content": "content",
- "roles": {
- "system": ["system"],
- "user": ["user"],
- "assistant": ["assistant"],
- },
- }
- ],
+ "datasets": [ALPACA_MESSAGES_CONFIG_OG],
}
)
@@ -342,24 +325,7 @@ class TestDatasetPreparation(unittest.TestCase):
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
- "datasets": [
- {
- "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
- "type": "chat_template.default",
- "chat_template": "llama3",
- "revision": "ea82cff",
- "field_messages": "conversation",
- "field_chosen": "chosen",
- "field_rejected": "rejected",
- "message_field_role": "role",
- "message_field_content": "content",
- "roles": {
- "system": ["system"],
- "user": ["user"],
- "assistant": ["assistant"],
- },
- }
- ],
+ "datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
}
)
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
new file mode 100644
index 000000000..2ac6415be
--- /dev/null
+++ b/tests/test_exact_deduplication.py
@@ -0,0 +1,433 @@
+"""
+Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
+
+Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
+"""
+import hashlib
+import unittest
+from unittest.mock import patch
+
+from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from axolotl.utils.data import prepare_dataset
+from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.utils import deduplicate_and_log_datasets
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_processor, load_tokenizer
+
+
+def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
+ """
+ Validates deduplication results and size consistency.
+
+ Parameters:
+ - actual_dataset: Deduplicated dataset.
+ - expected_dataset: Expected dataset.
+ - dataset_name: Name of the dataset (e.g., 'train' or 'eval').
+
+ Asserts:
+ - Datasets match in content.
+ - Dataset size matches unique row count.
+ """
+ # Convert datasets to sets of tuples for unordered comparison
+ actual_rows = set(tuple(row.values()) for row in actual_dataset)
+ expected_rows = set(tuple(row.values()) for row in expected_dataset)
+
+ # Verify deduplication correctness
+ assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset"
+
+ # Verify size consistency
+ assert len(actual_rows) == len(
+ actual_dataset
+ ), f"Size mismatch in {dataset_name} dataset after deduplication"
+
+
+class TestDeduplicateIndividualFunctions(unittest.TestCase):
+ """
+ test class for deduplication function in data utils
+ """
+
+ def setUp(self):
+ # Sample data with duplicates
+ self.data = {
+ "column1": ["apple", "banana", "apple", "orange", "banana"],
+ "column2": [1, 2, 1, 3, 2],
+ "column3": ["red", "yellow", "red", "orange", "yellow"],
+ }
+
+ # Expected result after deduplication
+ self.expected_data = {
+ "column1": ["apple", "banana", "orange"],
+ "column2": [1, 2, 3],
+ "column3": ["red", "yellow", "orange"],
+ }
+
+ # Convert to Dataset format
+ self.dataset = Dataset.from_dict(self.data)
+ self.expected_dataset = Dataset.from_dict(self.expected_data)
+
+ def test_deduplication(self):
+ train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
+ _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
+
+ verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
+ verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
+
+ def test_datasets_are_none(self):
+ # Test when both datasets are None
+ train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+ train_dataset=None, eval_dataset=None
+ )
+ self.assertIsNone(train_dataset, "Expected train_dataset to be None")
+ self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
+
+ def test_only_train_is_none(self):
+ # Test when only train_dataset is None
+ train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+ train_dataset=None, eval_dataset=self.dataset
+ )
+ self.assertIsNone(train_dataset, "Expected train_dataset to be None")
+ verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
+
+ def test_only_eval_is_none(self):
+ # Test when only eval_dataset is None
+ train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+ train_dataset=self.dataset, eval_dataset=None
+ )
+ self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
+ verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
+
+ def test_exact_duplicates(self):
+ # Test when datasets are exact duplicates
+ duplicate_data = {
+ "column1": ["apple", "apple", "apple"],
+ "column2": [1, 1, 1],
+ "column3": ["red", "red", "red"],
+ }
+ expected_data = {"column1": ["apple"], "column2": [1], "column3": ["red"]}
+
+ # Convert to Dataset format
+ dataset = Dataset.from_dict(duplicate_data)
+ expected_dataset = Dataset.from_dict(expected_data)
+
+ # Run deduplication
+ train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
+ _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+
+ verify_deduplication(train_dataset, expected_dataset, "train_dataset")
+ verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
+
+ def test_partial_duplicates(self):
+ # Test when only part of the dataset is a duplicate
+ partial_duplicate_data = {
+ "column1": ["apple", "banana", "apple"],
+ "column2": [1, 2, 1],
+ "column3": ["red", "yellow", "red"],
+ }
+ expected_data = {
+ "column1": ["apple", "banana"],
+ "column2": [1, 2],
+ "column3": ["red", "yellow"],
+ }
+
+ # Convert to Dataset format
+ dataset = Dataset.from_dict(partial_duplicate_data)
+ expected_dataset = Dataset.from_dict(expected_data)
+
+ # Run deduplication
+ train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
+ _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+
+ verify_deduplication(train_dataset, expected_dataset, "train_dataset")
+ verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
+
+ def test_combined_duplicates_empty(self):
+ # Test when only part of the dataset is a duplicate
+ partial_duplicate_data = {
+ "column1": ["apple", "banana", "apple"],
+ "column2": [1, 2, 1],
+ "column3": ["red", "yellow", "red"],
+ }
+ expected_data_train = {
+ "column1": ["apple", "banana"],
+ "column2": [1, 2],
+ "column3": ["red", "yellow"],
+ }
+ expected_data_eval = {
+ "column1": [],
+ "column2": [],
+ "column3": [],
+ }
+
+ # Convert to Dataset format
+ dataset = Dataset.from_dict(partial_duplicate_data)
+ expected_dataset_train = Dataset.from_dict(expected_data_train)
+ expected_dataset_eval = Dataset.from_dict(expected_data_eval)
+
+ # Run deduplication
+ train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+ train_dataset=dataset, eval_dataset=dataset
+ )
+
+ verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
+ verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
+
+ def test_combined_duplicates_one(self):
+ # Test when only part of the dataset is a duplicate
+ partial_duplicate_data_train = {
+ "column1": ["apple", "banana", "apple"],
+ "column2": [1, 2, 1],
+ "column3": ["red", "yellow", "red"],
+ }
+ partial_duplicate_data_eval = {
+ "column1": ["apple", "orange", "apple"],
+ "column2": [1, 2, 1],
+ "column3": ["red", "orange", "red"],
+ }
+ expected_data_train = {
+ "column1": ["apple", "banana"],
+ "column2": [1, 2],
+ "column3": ["red", "yellow"],
+ }
+ expected_data_eval = {
+ "column1": ["orange"],
+ "column2": [2],
+ "column3": ["orange"],
+ }
+
+ # Convert to Dataset format
+ dataset_train = Dataset.from_dict(partial_duplicate_data_train)
+ dataset_eval = Dataset.from_dict(partial_duplicate_data_eval)
+ expected_dataset_train = Dataset.from_dict(expected_data_train)
+ expected_dataset_eval = Dataset.from_dict(expected_data_eval)
+
+ # Run deduplication
+ train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+ train_dataset=dataset_train, eval_dataset=dataset_eval
+ )
+
+ verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
+ verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
+
+
+class TestDeduplicateRLDataset(unittest.TestCase):
+ """Test a configured dataloader with deduplication."""
+
+ def setUp(self) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
+ self.cfg = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "rl": "dpo",
+ "chat_template": "llama3",
+ "dataset_exact_deduplication": True,
+ "datasets": [
+ ALPACA_MESSAGES_CONFIG_REVISION,
+ ALPACA_MESSAGES_CONFIG_REVISION,
+ ],
+ }
+ )
+
+ def test_load_with_deduplication(self):
+ """Verify that loading with deduplication removes duplicates."""
+
+ # Load the dataset using the deduplication setting
+ train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+
+ # Verify that the dataset has been deduplicated
+ assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
+
+ def test_load_without_deduplication(self):
+ """Verify that loading without deduplication retains duplicates."""
+ self.cfg.dataset_exact_deduplication = False
+ # Load the dataset without deduplication
+ train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+
+ # Verify that the dataset retains duplicates
+ assert (
+ len(train_dataset) == 1800 * 2
+ ), "Dataset deduplication occurred when it should not have"
+
+
+class TestDeduplicateNonRL(unittest.TestCase):
+ """Test prepare_dataset function with different configurations."""
+
+ def setUp(self) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+ self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
+ self.cfg_1 = DictDefault(
+ {
+ "tokenizer_config": "huggyllama/llama-7b",
+ "sequence_len": 1024,
+ "dataset_exact_deduplication": True,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "val_set_size": 0.0,
+ "gradient_accumulation_steps": 4,
+ "batch_size": 10,
+ "micro_batch_size": 10,
+ "num_epochs": 1,
+ }
+ )
+
+ def test_prepare_dataset_with_deduplication_train(self):
+ """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
+ self.cfg_1.dataset_exact_deduplication = True
+
+ # Load tokenizer and processor
+ tokenizer = load_tokenizer(self.cfg_1)
+ processor = (
+ load_processor(self.cfg_1, tokenizer=tokenizer)
+ if self.cfg_1.processor_type
+ else None
+ )
+
+ # Prepare dataset using the prepare_dataset function
+ train_dataset, _, _, _ = prepare_dataset(
+ self.cfg_1,
+ tokenizer,
+ processor=processor,
+ )
+
+ self.assertEqual(
+ len(train_dataset),
+ 2000,
+ "Train dataset should have 2000 samples after deduplication.",
+ )
+
+ def test_prepare_dataset_with_deduplication_eval(self):
+ """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
+ self.cfg_1.dataset_exact_deduplication = True
+ self.cfg_1.val_set_size = 0.5
+ # Load tokenizer and processor
+ tokenizer = load_tokenizer(self.cfg_1)
+ processor = (
+ load_processor(self.cfg_1, tokenizer=tokenizer)
+ if self.cfg_1.processor_type
+ else None
+ )
+
+ # Prepare dataset using the prepare_dataset function
+ _, eval_dataset, _, _ = prepare_dataset(
+ self.cfg_1,
+ tokenizer,
+ processor=processor,
+ )
+
+ self.assertEqual(
+ len(eval_dataset),
+ 1000,
+ "Eval dataset should have 2000 samples after deduplication.",
+ )
+
+ def test_prepare_dataset_without_deduplication(self):
+ """Verify that prepare_dataset function processes the dataset correctly without deduplication."""
+ self.cfg_1.dataset_exact_deduplication = False
+ self.cfg_1.val_set_size = 0.1
+ # Load tokenizer and processor
+ tokenizer = load_tokenizer(self.cfg_1)
+ processor = (
+ load_processor(self.cfg_1, tokenizer=tokenizer)
+ if self.cfg_1.processor_type
+ else None
+ )
+
+ # Prepare dataset using the prepare_dataset function
+ train_dataset, eval_dataset, _, _ = prepare_dataset(
+ self.cfg_1,
+ tokenizer,
+ processor=processor,
+ )
+
+ # Verify that the dataset has been prepared correctly
+ self.assertEqual(
+ len(train_dataset),
+ 1800 * 2,
+ "Train dataset should have 3600 samples without deduplication.",
+ )
+ self.assertEqual(
+ len(eval_dataset),
+ 200 * 2,
+ "Train dataset should have 400 samples after deduplication.",
+ )
+
+
+class TestWrongCollisions(unittest.TestCase):
+ """Creating mock datasets for testing wrong collisions"""
+
+ def setUp(self):
+ self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}
+ self.eval_data = {
+ "text": [
+ "sample 5",
+ "sample 7",
+ ], # Different label but same text as in train_data
+ "label": [2, 3],
+ }
+ self.dataset_data = {
+ "text": ["sample 5", "sample 9", "sample 5"],
+ "label": [1, 2, 8],
+ }
+ self.train_dataset = Dataset.from_dict(self.train_data)
+ self.eval_dataset = Dataset.from_dict(self.eval_data)
+ self.dataset = Dataset.from_dict(self.dataset_data)
+
+ @patch(
+ "axolotl.utils.data.utils.sha256",
+ side_effect=lambda x: hashlib.sha256(
+ "forced_collision_hash".encode("utf-8")
+ ).hexdigest()
+ if "sample 5" in x
+ else hashlib.sha256(x.encode("utf-8")).hexdigest(),
+ )
+ def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
+ dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
+ train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
+ )
+ self.assertEqual(
+ len(dedup_train),
+ 2,
+ "train dataset should not deduplicate rows with forced hash collisions but different labels.",
+ )
+ self.assertEqual(
+ len(dedup_eval),
+ 2,
+ "Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
+ )
+ self.assertEqual(
+ len(dedup_eval),
+ len(self.eval_dataset),
+ "The output eval dataset should have the same number of rows as the input eval dataset.",
+ )
+ self.assertEqual(
+ str(dedup_eval),
+ str(self.eval_dataset),
+ "The string representation of the output eval dataset should be identical to the input eval dataset.",
+ )
+
+ def test_deduplication_dataset_only(self):
+ _, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
+ self.assertEqual(
+ len(dedup_dataset), 3, "Dataset should have all original values"
+ )
+ self.assertEqual(
+ str(dedup_dataset),
+ str(self.dataset),
+ "The string representation of the output dataset should not differ.",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()