From b0fbd4d11dea5171b177d4341645c0e4b821db1e Mon Sep 17 00:00:00 2001 From: Oliver Molenschot <91694286+olivermolenschot@users.noreply.github.com> Date: Mon, 2 Dec 2024 05:47:10 -0800 Subject: [PATCH] Add Exact Deduplication Feature to Preprocessing Pipeline (#2072) * Add example YAML file for training Mistral using DPO * added deduplication code * Add exact deduplication feature and update examples * Improve deduplication for train/eval overlap Changed the deduplication function to use a more memory-efficient hashing method. Applied GitHub suggestions to improve clarity and maintainability. The deduplication now handles cases where train and eval datasets have overlapping elements. * Improve deduplication for train/eval overlap Changed the deduplication function to use a more memory-efficient hashing method. Applied GitHub suggestions to improve clarity and maintainability. The deduplication now handles cases where train and eval datasets have overlapping elements. * Apply suggestions from code review To handle the original case where we do not do deduplication Co-authored-by: Wing Lian * Improve false collision detection to ensure dataset integrity - Added test cases to simulate and verify handling of forced hash collisions between datasets. - Ensured that datasets with identical hashes but different content are correctly identified, preventing incorrect deduplication. - Updated unit tests to include scenarios where collisions occur across both training and evaluation datasets, as well as within a single dataset. * Moved the constants file to the tests folder - Relocated `constants.py` to the `tests` folder to improve modularity and maintain a clear separation between source and test files. - Renamed `cicd/tests.py` to `cicd/cicd_tests.py` to resolve a conflict with `tests/__init__.py`, which caused Mypy to fail due to duplicate module names. 
- Updated all references to `cicd.tests` in the codebase to `cicd.cicd_tests` to reflect the renaming and ensure compatibility. - These changes ensure Mypy passes the pre-commit hook and maintain alignment with the project's structure. * revert some changes from previous commit and fix relative import --------- Co-authored-by: Wing Lian Co-authored-by: Wing Lian --- docs/config.qmd | 3 + examples/llama-3/lora-1b-deduplicate-dpo.yml | 95 ++++ examples/llama-3/lora-1b-deduplicate-sft.yml | 76 +++ src/axolotl/cli/__init__.py | 2 +- .../config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/data/rl.py | 7 +- src/axolotl/utils/data/sft.py | 21 +- src/axolotl/utils/data/utils.py | 98 ++++ tests/constants.py | 32 ++ tests/test_datasets.py | 50 +- tests/test_exact_deduplication.py | 433 ++++++++++++++++++ 11 files changed, 767 insertions(+), 51 deletions(-) create mode 100644 examples/llama-3/lora-1b-deduplicate-dpo.yml create mode 100644 examples/llama-3/lora-1b-deduplicate-sft.yml create mode 100644 tests/constants.py create mode 100644 tests/test_exact_deduplication.py diff --git a/docs/config.qmd b/docs/config.qmd index 04e278e2d..bc3730095 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -162,6 +162,9 @@ datasets: # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true. shuffle_merged_datasets: true +Deduplicates datasets and test_datasets with identical entries. +dataset_exact_deduplication: true + # A list of one or more datasets to eval the model with. # You can use either test_datasets, or val_set_size, but not both. 
test_datasets: diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml new file mode 100644 index 000000000..35a0260ca --- /dev/null +++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml @@ -0,0 +1,95 @@ +base_model: meta-llama/Llama-3.2-1B +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +chat_template: llama3 +rl: dpo +datasets: + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_field_role: role + message_field_content: content + roles: + system: + - system + user: + - user + assistant: + - assistant + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_field_role: role + message_field_content: content + roles: + system: + - system + user: + - user + assistant: + - assistant + +dataset_exact_deduplication: true +dataset_prepared_path: +val_set_size: 0 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: 
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml new file mode 100644 index 000000000..c07d5f8ff --- /dev/null +++ b/examples/llama-3/lora-1b-deduplicate-sft.yml @@ -0,0 +1,76 @@ +base_model: meta-llama/Llama-3.2-1B +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./outputs/lora-out + +dataset_exact_deduplication: true +test_value: true + +sequence_len: 4096 +sample_packing: true +eval_sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_modules_to_save: + - embed_tokens + - lm_head + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: <|end_of_text|> diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 7c8db7ce8..86cc30a40 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -139,7 +139,7 @@ def check_remote_config(config: Union[str, Path]): with open(output_path, "wb") as file: file.write(content) LOG.info( - f"Using the following config obtained from 
{config}:\n\n{content.decode('utf-8')}\n" + f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n" ) return output_path diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index c7d5848f9..378dfef86 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -625,6 +625,7 @@ class AxolotlInputConfig( json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) dataset_processes: Optional[int] = Field(default=os.cpu_count()) + dataset_exact_deduplication: Optional[bool] = None dataset_keep_in_memory: Optional[bool] = None dataloader_pin_memory: Optional[bool] = None dataloader_num_workers: Optional[int] = None diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index cf1226175..edb72f186 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo -from axolotl.utils.data.utils import md5 +from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.models import load_tokenizer @@ -208,4 +208,9 @@ def load_prepare_dpo_datasets(cfg): if eval_dataset and not eval_is_preprocessed: _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset) + if cfg.dataset_exact_deduplication: + train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( + train_dataset=train_dataset, eval_dataset=eval_dataset + ) + return train_dataset, eval_dataset diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index b72e0a2b3..0bee4dd5c 
100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -44,7 +44,7 @@ from axolotl.prompters import ( UnsupportedPrompter, ) from axolotl.utils.data.pretraining import wrap_pretraining_dataset -from axolotl.utils.data.utils import md5 +from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_local_main_process, zero_first from axolotl.utils.trainer import ( @@ -136,8 +136,9 @@ def prepare_dataset(cfg, tokenizer, processor=None): # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") eval_dataset = None + if cfg.dataset_exact_deduplication: + LOG.info("Deduplication not available for pretrained datasets") return train_dataset, eval_dataset, cfg.max_steps, prompters - if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: @@ -178,7 +179,7 @@ def load_tokenized_prepared_datasets( + "|".join( sorted( [ - f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}" + f"{d.path}: {d.type}: {d.shards}: {d.conversation}{d.split}" for d in cfg_datasets ] ) @@ -584,7 +585,8 @@ def load_prepare_datasets( ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) - + if cfg.dataset_exact_deduplication: + _, _, dataset = deduplicate_and_log_datasets(dataset=dataset) dataset = dataset.train_test_split( test_size=val_set_size, shuffle=False, @@ -596,12 +598,17 @@ def load_prepare_datasets( train_dataset = dataset["train"] eval_dataset = dataset["test"] elif split == "test": + if cfg.dataset_exact_deduplication: + _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) + else: + eval_dataset = dataset train_dataset = None - eval_dataset = dataset else: - 
train_dataset = dataset + if cfg.dataset_exact_deduplication: + train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) + else: + train_dataset = dataset eval_dataset = None - return train_dataset, eval_dataset, prompters diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index e05701e7b..56bcddd8e 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -1,6 +1,11 @@ """data handling helpers""" import hashlib +import logging + +from datasets import Dataset + +LOG = logging.getLogger("axolotl") def md5(to_hash: str, encoding: str = "utf-8") -> str: @@ -8,3 +13,96 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str: return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() except TypeError: return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec + + +def sha256(to_hash: str, encoding: str = "utf-8") -> str: + return hashlib.sha256(to_hash.encode(encoding)).hexdigest() + + +def deduplicate_dataset( + dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None +) -> Dataset: + unique_indices = [] + + for idx, row in enumerate(dataset): + row_hash = sha256(str(row)) # Using SHA256 for collision resistance. 
+ if row_hash not in seen_hashes: + seen_hashes[row_hash] = [idx] + unique_indices.append(idx) + else: + # Check for collision by looking up the original dataset indices + original_indices = seen_hashes[row_hash] + is_duplicate = False + for original_idx in original_indices: + if ( + not idx == original_idx + and original_idx < len(dataset) + and str(dataset[original_idx]) == str(row) + ): + is_duplicate = True + break + # Check in the other dataset if provided + if other_dataset is not None: + if original_idx < len(other_dataset) and str( + other_dataset[original_idx] + ) == str(row): + is_duplicate = True + break + if not is_duplicate: + seen_hashes[row_hash].append(idx) + unique_indices.append(idx) + continue + return dataset.select(unique_indices) + + +def deduplicate_and_log_datasets( + *, + train_dataset: Dataset = None, + eval_dataset: Dataset = None, + dataset: Dataset = None, +) -> tuple[Dataset, Dataset, Dataset]: + """ + Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes. + + Returns: + tuple: Deduplicated train, eval, and additional datasets. + """ + seen_hashes: dict[str, list[int]] = {} + + # Handle cases where datasets are None + if train_dataset is not None: + LOG.info( + f"Starting deduplication for train dataset. Original size: {len(train_dataset)}" + ) + train_dataset = deduplicate_dataset( + dataset=train_dataset, seen_hashes=seen_hashes + ) + LOG.info( + f"Deduplication complete for train dataset. New size: {len(train_dataset)}" + ) + else: + LOG.info("Train dataset is None. Skipping deduplication.") + + if eval_dataset is not None: + LOG.info( + f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}" + ) + eval_dataset = deduplicate_dataset( + dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset + ) + LOG.info( + f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}" + ) + else: + LOG.info("Eval dataset is None. 
Skipping deduplication.") + + if dataset is not None and (eval_dataset is None and train_dataset is None): + LOG.info( + f"Starting deduplication for combined dataset. Original size: {len(dataset)}" + ) + dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes) + LOG.info( + f"Deduplication complete for combined dataset. New size: {len(dataset)}" + ) + + return train_dataset, eval_dataset, dataset diff --git a/tests/constants.py b/tests/constants.py new file mode 100644 index 000000000..e024e6920 --- /dev/null +++ b/tests/constants.py @@ -0,0 +1,32 @@ +# constants.py +""" +This module contains constants and configuration dictionaries used for +datasets and other utilities in the Axolotl project, specifically for testing. +""" +# Configuration for Alpaca Messages Dataset +ALPACA_MESSAGES_CONFIG_OG = { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "chat_template": "llama3", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, +} + +# Revision configuration extending the original +ALPACA_MESSAGES_CONFIG_REVISION = ALPACA_MESSAGES_CONFIG_OG.copy() +ALPACA_MESSAGES_CONFIG_REVISION["revision"] = "ea82cff" + + +SPECIAL_TOKENS = { + "bos_token": "", + "eos_token": "", + "unk_token": "", +} diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e87f19cc7..f3bed00fd 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -7,6 +7,11 @@ import tempfile import unittest from pathlib import Path +from constants import ( + ALPACA_MESSAGES_CONFIG_OG, + ALPACA_MESSAGES_CONFIG_REVISION, + SPECIAL_TOKENS, +) from datasets import Dataset from huggingface_hub import snapshot_download from transformers import AutoTokenizer @@ -21,13 +26,7 @@ class TestDatasetPreparation(unittest.TestCase): def 
setUp(self) -> None: self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - } - ) + self.tokenizer.add_special_tokens(SPECIAL_TOKENS) # Alpaca dataset. self.dataset = Dataset.from_list( [ @@ -277,23 +276,7 @@ class TestDatasetPreparation(unittest.TestCase): "sequence_len": 1024, "rl": "dpo", "chat_template": "llama3", - "datasets": [ - { - "path": "fozziethebeat/alpaca_messages_2k_dpo_test", - "type": "chat_template.default", - "chat_template": "llama3", - "field_messages": "conversation", - "field_chosen": "chosen", - "field_rejected": "rejected", - "message_field_role": "role", - "message_field_content": "content", - "roles": { - "system": ["system"], - "user": ["user"], - "assistant": ["assistant"], - }, - } - ], + "datasets": [ALPACA_MESSAGES_CONFIG_OG], } ) @@ -342,24 +325,7 @@ class TestDatasetPreparation(unittest.TestCase): "sequence_len": 1024, "rl": "dpo", "chat_template": "llama3", - "datasets": [ - { - "path": "fozziethebeat/alpaca_messages_2k_dpo_test", - "type": "chat_template.default", - "chat_template": "llama3", - "revision": "ea82cff", - "field_messages": "conversation", - "field_chosen": "chosen", - "field_rejected": "rejected", - "message_field_role": "role", - "message_field_content": "content", - "roles": { - "system": ["system"], - "user": ["user"], - "assistant": ["assistant"], - }, - } - ], + "datasets": [ALPACA_MESSAGES_CONFIG_REVISION], } ) diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py new file mode 100644 index 000000000..2ac6415be --- /dev/null +++ b/tests/test_exact_deduplication.py @@ -0,0 +1,433 @@ +""" +Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function. + +Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command. 
+""" +import hashlib +import unittest +from unittest.mock import patch + +from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS +from datasets import Dataset +from transformers import AutoTokenizer + +from axolotl.utils.data import prepare_dataset +from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.utils import deduplicate_and_log_datasets +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_processor, load_tokenizer + + +def verify_deduplication(actual_dataset, expected_dataset, dataset_name): + """ + Validates deduplication results and size consistency. + + Parameters: + - actual_dataset: Deduplicated dataset. + - expected_dataset: Expected dataset. + - dataset_name: Name of the dataset (e.g., 'train' or 'eval'). + + Asserts: + - Datasets match in content. + - Dataset size matches unique row count. + """ + # Convert datasets to sets of tuples for unordered comparison + actual_rows = set(tuple(row.values()) for row in actual_dataset) + expected_rows = set(tuple(row.values()) for row in expected_dataset) + + # Verify deduplication correctness + assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset" + + # Verify size consistency + assert len(actual_rows) == len( + actual_dataset + ), f"Size mismatch in {dataset_name} dataset after deduplication" + + +class TestDeduplicateIndividualFunctions(unittest.TestCase): + """ + test class for deduplication function in data utils + """ + + def setUp(self): + # Sample data with duplicates + self.data = { + "column1": ["apple", "banana", "apple", "orange", "banana"], + "column2": [1, 2, 1, 3, 2], + "column3": ["red", "yellow", "red", "orange", "yellow"], + } + + # Expected result after deduplication + self.expected_data = { + "column1": ["apple", "banana", "orange"], + "column2": [1, 2, 3], + "column3": ["red", "yellow", "orange"], + } + + # Convert to Dataset format + self.dataset = Dataset.from_dict(self.data) + 
self.expected_dataset = Dataset.from_dict(self.expected_data) + + def test_deduplication(self): + train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset) + _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset) + + verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") + verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") + + def test_datasets_are_none(self): + # Test when both datasets are None + train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( + train_dataset=None, eval_dataset=None + ) + self.assertIsNone(train_dataset, "Expected train_dataset to be None") + self.assertIsNone(eval_dataset, "Expected eval_dataset to be None") + + def test_only_train_is_none(self): + # Test when only train_dataset is None + train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( + train_dataset=None, eval_dataset=self.dataset + ) + self.assertIsNone(train_dataset, "Expected train_dataset to be None") + verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") + + def test_only_eval_is_none(self): + # Test when only eval_dataset is None + train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( + train_dataset=self.dataset, eval_dataset=None + ) + self.assertIsNone(eval_dataset, "Expected eval_dataset to be None") + verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") + + def test_exact_duplicates(self): + # Test when datasets are exact duplicates + duplicate_data = { + "column1": ["apple", "apple", "apple"], + "column2": [1, 1, 1], + "column3": ["red", "red", "red"], + } + expected_data = {"column1": ["apple"], "column2": [1], "column3": ["red"]} + + # Convert to Dataset format + dataset = Dataset.from_dict(duplicate_data) + expected_dataset = Dataset.from_dict(expected_data) + + # Run deduplication + train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) + _, eval_dataset, _ = 
deduplicate_and_log_datasets(eval_dataset=dataset) + + verify_deduplication(train_dataset, expected_dataset, "train_dataset") + verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") + + def test_partial_duplicates(self): + # Test when only part of the dataset is a duplicate + partial_duplicate_data = { + "column1": ["apple", "banana", "apple"], + "column2": [1, 2, 1], + "column3": ["red", "yellow", "red"], + } + expected_data = { + "column1": ["apple", "banana"], + "column2": [1, 2], + "column3": ["red", "yellow"], + } + + # Convert to Dataset format + dataset = Dataset.from_dict(partial_duplicate_data) + expected_dataset = Dataset.from_dict(expected_data) + + # Run deduplication + train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) + _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) + + verify_deduplication(train_dataset, expected_dataset, "train_dataset") + verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") + + def test_combined_duplicates_empty(self): + # Test when only part of the dataset is a duplicate + partial_duplicate_data = { + "column1": ["apple", "banana", "apple"], + "column2": [1, 2, 1], + "column3": ["red", "yellow", "red"], + } + expected_data_train = { + "column1": ["apple", "banana"], + "column2": [1, 2], + "column3": ["red", "yellow"], + } + expected_data_eval = { + "column1": [], + "column2": [], + "column3": [], + } + + # Convert to Dataset format + dataset = Dataset.from_dict(partial_duplicate_data) + expected_dataset_train = Dataset.from_dict(expected_data_train) + expected_dataset_eval = Dataset.from_dict(expected_data_eval) + + # Run deduplication + train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( + train_dataset=dataset, eval_dataset=dataset + ) + + verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") + verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset") + + def test_combined_duplicates_one(self): + # 
Test when only part of the dataset is a duplicate + partial_duplicate_data_train = { + "column1": ["apple", "banana", "apple"], + "column2": [1, 2, 1], + "column3": ["red", "yellow", "red"], + } + partial_duplicate_data_eval = { + "column1": ["apple", "orange", "apple"], + "column2": [1, 2, 1], + "column3": ["red", "orange", "red"], + } + expected_data_train = { + "column1": ["apple", "banana"], + "column2": [1, 2], + "column3": ["red", "yellow"], + } + expected_data_eval = { + "column1": ["orange"], + "column2": [2], + "column3": ["orange"], + } + + # Convert to Dataset format + dataset_train = Dataset.from_dict(partial_duplicate_data_train) + dataset_eval = Dataset.from_dict(partial_duplicate_data_eval) + expected_dataset_train = Dataset.from_dict(expected_data_train) + expected_dataset_eval = Dataset.from_dict(expected_data_eval) + + # Run deduplication + train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( + train_dataset=dataset_train, eval_dataset=dataset_eval + ) + + verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") + verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset") + + +class TestDeduplicateRLDataset(unittest.TestCase): + """Test a configured dataloader with deduplication.""" + + def setUp(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens(SPECIAL_TOKENS) + self.cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "rl": "dpo", + "chat_template": "llama3", + "dataset_exact_deduplication": True, + "datasets": [ + ALPACA_MESSAGES_CONFIG_REVISION, + ALPACA_MESSAGES_CONFIG_REVISION, + ], + } + ) + + def test_load_with_deduplication(self): + """Verify that loading with deduplication removes duplicates.""" + + # Load the dataset using the deduplication setting + train_dataset, _ = load_prepare_dpo_datasets(self.cfg) + + # Verify that the dataset has been deduplicated + assert 
len(train_dataset) == 1800, "Dataset was not properly deduplicated" + + def test_load_without_deduplication(self): + """Verify that loading without deduplication retains duplicates.""" + self.cfg.dataset_exact_deduplication = False + # Load the dataset without deduplication + train_dataset, _ = load_prepare_dpo_datasets(self.cfg) + + # Verify that the dataset retains duplicates + assert ( + len(train_dataset) == 1800 * 2 + ), "Dataset deduplication occurred when it should not have" + + +class TestDeduplicateNonRL(unittest.TestCase): + """Test prepare_dataset function with different configurations.""" + + def setUp(self) -> None: + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens(SPECIAL_TOKENS) + self.cfg_1 = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 1024, + "dataset_exact_deduplication": True, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "val_set_size": 0.0, + "gradient_accumulation_steps": 4, + "batch_size": 10, + "micro_batch_size": 10, + "num_epochs": 1, + } + ) + + def test_prepare_dataset_with_deduplication_train(self): + """Verify that prepare_dataset function processes the dataset correctly with deduplication.""" + self.cfg_1.dataset_exact_deduplication = True + + # Load tokenizer and processor + tokenizer = load_tokenizer(self.cfg_1) + processor = ( + load_processor(self.cfg_1, tokenizer=tokenizer) + if self.cfg_1.processor_type + else None + ) + + # Prepare dataset using the prepare_dataset function + train_dataset, _, _, _ = prepare_dataset( + self.cfg_1, + tokenizer, + processor=processor, + ) + + self.assertEqual( + len(train_dataset), + 2000, + "Train dataset should have 2000 samples after deduplication.", + ) + + def test_prepare_dataset_with_deduplication_eval(self): + """Verify that prepare_dataset function processes the dataset correctly 
with deduplication.""" + self.cfg_1.dataset_exact_deduplication = True + self.cfg_1.val_set_size = 0.5 + # Load tokenizer and processor + tokenizer = load_tokenizer(self.cfg_1) + processor = ( + load_processor(self.cfg_1, tokenizer=tokenizer) + if self.cfg_1.processor_type + else None + ) + + # Prepare dataset using the prepare_dataset function + _, eval_dataset, _, _ = prepare_dataset( + self.cfg_1, + tokenizer, + processor=processor, + ) + + self.assertEqual( + len(eval_dataset), + 1000, + "Eval dataset should have 2000 samples after deduplication.", + ) + + def test_prepare_dataset_without_deduplication(self): + """Verify that prepare_dataset function processes the dataset correctly without deduplication.""" + self.cfg_1.dataset_exact_deduplication = False + self.cfg_1.val_set_size = 0.1 + # Load tokenizer and processor + tokenizer = load_tokenizer(self.cfg_1) + processor = ( + load_processor(self.cfg_1, tokenizer=tokenizer) + if self.cfg_1.processor_type + else None + ) + + # Prepare dataset using the prepare_dataset function + train_dataset, eval_dataset, _, _ = prepare_dataset( + self.cfg_1, + tokenizer, + processor=processor, + ) + + # Verify that the dataset has been prepared correctly + self.assertEqual( + len(train_dataset), + 1800 * 2, + "Train dataset should have 3600 samples without deduplication.", + ) + self.assertEqual( + len(eval_dataset), + 200 * 2, + "Train dataset should have 400 samples after deduplication.", + ) + + +class TestWrongCollisions(unittest.TestCase): + """Creating mock datasets for testing wrong collisions""" + + def setUp(self): + self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]} + self.eval_data = { + "text": [ + "sample 5", + "sample 7", + ], # Different label but same text as in train_data + "label": [2, 3], + } + self.dataset_data = { + "text": ["sample 5", "sample 9", "sample 5"], + "label": [1, 2, 8], + } + self.train_dataset = Dataset.from_dict(self.train_data) + self.eval_dataset = 
Dataset.from_dict(self.eval_data) + self.dataset = Dataset.from_dict(self.dataset_data) + + @patch( + "axolotl.utils.data.utils.sha256", + side_effect=lambda x: hashlib.sha256( + "forced_collision_hash".encode("utf-8") + ).hexdigest() + if "sample 5" in x + else hashlib.sha256(x.encode("utf-8")).hexdigest(), + ) + def test_deduplication_wrong_collision_train_eval(self, _mock_sha256): + dedup_train, dedup_eval, _ = deduplicate_and_log_datasets( + train_dataset=self.train_dataset, eval_dataset=self.eval_dataset + ) + self.assertEqual( + len(dedup_train), + 2, + "train dataset should not deduplicate rows with forced hash collisions but different labels.", + ) + self.assertEqual( + len(dedup_eval), + 2, + "Eval dataset should not deduplicate rows with forced hash collisions but different labels.", + ) + self.assertEqual( + len(dedup_eval), + len(self.eval_dataset), + "The output eval dataset should have the same number of rows as the input eval dataset.", + ) + self.assertEqual( + str(dedup_eval), + str(self.eval_dataset), + "The string representation of the output eval dataset should be identical to the input eval dataset.", + ) + + def test_deduplication_dataset_only(self): + _, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset) + self.assertEqual( + len(dedup_dataset), 3, "Dataset should have all original values" + ) + self.assertEqual( + str(dedup_dataset), + str(self.dataset), + "The string representation of the output dataset should not differ.", + ) + + +if __name__ == "__main__": + unittest.main()