From 474208b794c1b614a30a0ba0a41fa1d9af15b0ac Mon Sep 17 00:00:00 2001
From: Manas Vardhan
Date: Mon, 2 Mar 2026 09:55:59 -0800
Subject: [PATCH] fix: Save de-duplicated dataset during pre-processing (#3427)

* fix: run deduplication before saving dataset during preprocessing

Move deduplicate_and_log_datasets call before save_preprocessed_dataset in
both SFT and RL data loading pipelines. This ensures the saved preprocessed
dataset is already de-duplicated, so subsequent loads from cache don't
contain duplicates.

Fixes #2719

* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup

- Add dataset_exact_deduplication to the hash string in
  generate_dataset_hash_from_config so cached datasets are invalidated when
  the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and dataset_exact_deduplication=True,
  since dedup will be silently skipped in that configuration (both SFT and
  RL paths).

* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting

- Add config validator (check_deduplication_with_skip_prepare) that raises
  ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
  axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues

* refactor: inline deduplicate function per review feedback

* fix test fixture, lint

---------

Co-authored-by: ManasVardhan
Co-authored-by: Wing Lian
---
 src/axolotl/utils/data/rl.py        |   4 +
 src/axolotl/utils/data/sft.py       |  28 ++--
 src/axolotl/utils/data/shared.py    |   3 +-
 src/axolotl/utils/schemas/config.py |  13 ++
 tests/test_save_deduplicated.py     | 210 ++++++++++++++++++++++++++++
 5 files changed, 237 insertions(+), 21 deletions(-)
 create mode 100644 tests/test_save_deduplicated.py

diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 5ea9e55e0..2c386f35e 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -246,6 +246,10 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     dataset = merge_datasets(split_datasets, cfg)
 
     if not cfg.skip_prepare_dataset:
+        # Deduplicate before saving so the saved dataset is already de-duplicated
+        if cfg.dataset_exact_deduplication:
+            dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
+
         # Save preprocessed dataset
         dataset_hash = generate_dataset_hash_from_config(
             cfg, datasets_configs, tokenizer.name_or_path
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index ba5aec2d6..69cbfb871 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -351,6 +351,10 @@ def _load_raw_datasets(
     if cfg.sample_packing:
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)
 
+    # Deduplicate before saving so the saved dataset is already de-duplicated
+    if cfg.dataset_exact_deduplication:
+        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
+
     # Save the prepared dataset
     dataset_hash = generate_dataset_hash_from_config(
         cfg, datasets_configs, tokenizer.name_or_path
@@ -438,25 +442,8 @@ def _handle_train_dataset_split(
         )
         return train_dataset, eval_dataset
 
-    # No validation split - apply deduplication if needed and return as train dataset
-    if cfg.dataset_exact_deduplication:
-        train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
-    else:
-        train_dataset = dataset
-
-    return train_dataset, None
-
-
-def _handle_test_dataset_split(
-    dataset: Dataset, cfg: DictDefault
-) -> tuple[None, Dataset | None]:
-    """Handle processing for test split."""
-    if cfg.dataset_exact_deduplication:
-        eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
-    else:
-        eval_dataset = dataset
-
-    return None, eval_dataset
+    # No validation split - deduplication already applied during preprocessing
+    return dataset, None
 
 
 def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
@@ -515,6 +502,7 @@ def _load_and_prepare_datasets(
     if split == "train":
         train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)
     else:
-        train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
+        # Deduplication already applied during preprocessing
+        train_dataset, eval_dataset = None, dataset
 
     return train_dataset, eval_dataset, prompters
diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py
index a8ed55ae2..351669ec3 100644
--- a/src/axolotl/utils/data/shared.py
+++ b/src/axolotl/utils/data/shared.py
@@ -520,7 +520,8 @@ def generate_dataset_hash_from_config(
     """
     config_str = (
         f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
-        f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
+        f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
+        f"{cfg.dataset_exact_deduplication or False}|"
         f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
         f"|{tokenizer_name}"
     )
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 35b4a6908..b15b99955 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1509,3 +1509,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
                 "dataset_exact_deduplication is not available for streaming datasets. "
             )
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_deduplication_with_skip_prepare(cls, data):
+        if data.get("dataset_exact_deduplication") and data.get("skip_prepare_dataset"):
+            raise ValueError(
+                "dataset_exact_deduplication=True has no effect when "
+                "skip_prepare_dataset=True. Deduplication runs as part of the "
+                "prepare pipeline, which is skipped. Either set "
+                "skip_prepare_dataset: false or disable "
+                "dataset_exact_deduplication."
+            )
+        return data
diff --git a/tests/test_save_deduplicated.py b/tests/test_save_deduplicated.py
new file mode 100644
index 000000000..1e41c3e10
--- /dev/null
+++ b/tests/test_save_deduplicated.py
@@ -0,0 +1,210 @@
+"""Tests to verify that deduplication runs before dataset saving during preprocessing.
+
+This addresses GitHub issue #2719: Save De-duplicated Set During Pre-processing.
+""" + +from unittest.mock import MagicMock, patch + +from datasets import Dataset + +from axolotl.utils.dict import DictDefault + + +class TestSFTSaveDeduplicatedBeforeSave: + """Verify that in SFT data loading, deduplication occurs before saving.""" + + @patch("axolotl.utils.data.sft.save_preprocessed_dataset") + @patch("axolotl.utils.data.sft.generate_dataset_hash_from_config") + @patch("axolotl.utils.data.sft.deduplicate_and_log_datasets") + @patch("axolotl.utils.data.sft.merge_datasets") + @patch("axolotl.utils.data.sft._load_and_process_single_dataset") + @patch("axolotl.utils.data.sft.datasets_with_name_generator") + def test_dedup_called_before_save_sft( + self, + mock_datasets_gen, + mock_load_single, + mock_merge, + mock_dedup, + mock_gen_hash, + mock_save, + ): + """Deduplication should be called before save_preprocessed_dataset in SFT.""" + from axolotl.utils.data.sft import _load_raw_datasets + + # Set up mock data + dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]}) + deduped_dataset = Dataset.from_dict({"text": ["a", "b"], "label": [1, 2]}) + + mock_datasets_gen.return_value = [ + DictDefault({"path": "test", "type": "alpaca"}) + ] + mock_load_single.return_value = (dataset, None) + mock_merge.return_value = dataset + mock_dedup.return_value = (deduped_dataset, None) + mock_gen_hash.return_value = "testhash" + + cfg = DictDefault( + { + "skip_prepare_dataset": False, + "dataset_exact_deduplication": True, + "sequence_len": 1024, + "eval_sequence_len": None, + "sample_packing": False, + "is_preprocess": False, + "seed": 42, + "datasets": [{"path": "test", "type": "alpaca"}], + } + ) + + tokenizer = MagicMock() + tokenizer.name_or_path = "test-tokenizer" + + # Track call order + call_order = [] + mock_dedup.side_effect = lambda **kwargs: ( + call_order.append("dedup") or (deduped_dataset, None) + ) + mock_save.side_effect = lambda *args, **kwargs: call_order.append("save") + + _load_raw_datasets( + cfg=cfg, + datasets_configs=cfg.datasets, + tokenizer=tokenizer, + split="train", + ) + + # Verify dedup was called + assert "dedup" in call_order, "Deduplication should have been called" + # Verify save was called + assert "save" in call_order, "Save should have been called" + # Verify dedup happened before save + assert call_order.index("dedup") < call_order.index("save"), ( + "Deduplication must occur before saving the dataset" + ) + + @patch("axolotl.utils.data.sft.save_preprocessed_dataset") + @patch("axolotl.utils.data.sft.generate_dataset_hash_from_config") + @patch("axolotl.utils.data.sft.merge_datasets") + @patch("axolotl.utils.data.sft._load_and_process_single_dataset") + @patch("axolotl.utils.data.sft.datasets_with_name_generator") + def test_no_dedup_when_disabled_sft( + self, + mock_datasets_gen, + mock_load_single, + mock_merge, + mock_gen_hash, + mock_save, + ): + """Deduplication should not be called when dataset_exact_deduplication is False.""" + from axolotl.utils.data.sft import _load_raw_datasets + + dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]}) + + mock_datasets_gen.return_value = [ + DictDefault({"path": "test", "type": "alpaca"}) + ] + mock_load_single.return_value = (dataset, None) + mock_merge.return_value = dataset + mock_gen_hash.return_value = "testhash" + + cfg = DictDefault( + { + "skip_prepare_dataset": False, + "dataset_exact_deduplication": False, + "sequence_len": 1024, + "eval_sequence_len": None, + "sample_packing": False, + "is_preprocess": False, + "seed": 42, + "datasets": [{"path": "test", 
"type": "alpaca"}], + } + ) + + tokenizer = MagicMock() + tokenizer.name_or_path = "test-tokenizer" + + with patch("axolotl.utils.data.sft.deduplicate_and_log_datasets") as mock_dedup: + _load_raw_datasets( + cfg=cfg, + datasets_configs=cfg.datasets, + tokenizer=tokenizer, + split="train", + ) + mock_dedup.assert_not_called() + + +class TestRLSaveDeduplicatedBeforeSave: + """Verify that in RL data loading, deduplication occurs before saving.""" + + @patch.object(Dataset, "filter", lambda self, *args, **kwargs: self) + @patch("axolotl.utils.data.rl.save_preprocessed_dataset") + @patch("axolotl.utils.data.rl.generate_dataset_hash_from_config") + @patch("axolotl.utils.data.rl.deduplicate_and_log_datasets") + @patch("axolotl.utils.data.rl.merge_datasets") + @patch("axolotl.utils.data.rl.load_dataset_with_config") + @patch("axolotl.utils.data.rl.datasets_with_name_generator") + @patch("axolotl.utils.data.rl.load_tokenizer") + def test_dedup_called_before_save_rl( + self, + mock_load_tokenizer, + mock_datasets_gen, + mock_load_dataset, + mock_merge, + mock_dedup, + mock_gen_hash, + mock_save, + ): + """Deduplication should be called before save_preprocessed_dataset in RL.""" + from axolotl.utils.data.rl import _load_split + + dataset = Dataset.from_dict( + { + "prompt": ["hi", "bye", "hi"], + "chosen": ["a", "b", "a"], + "rejected": ["c", "d", "c"], + } + ) + deduped_dataset = Dataset.from_dict( + { + "prompt": ["hi", "bye"], + "chosen": ["a", "b"], + "rejected": ["c", "d"], + } + ) + + mock_datasets_gen.return_value = [DictDefault({"path": "test", "type": None})] + mock_load_dataset.return_value = dataset + mock_merge.return_value = dataset + mock_dedup.return_value = (deduped_dataset, None) + mock_gen_hash.return_value = "testhash" + + tokenizer = MagicMock() + tokenizer.name_or_path = "test-tokenizer" + mock_load_tokenizer.return_value = tokenizer + + cfg = DictDefault( + { + "skip_prepare_dataset": False, + "dataset_exact_deduplication": True, + "sequence_len": 1024, + "rl": "dpo", + "datasets": [{"path": "test", "type": None}], + "hf_use_auth_token": False, + "dataset_num_proc": 1, + "is_preprocess": False, + } + ) + + call_order = [] + mock_dedup.side_effect = lambda **kwargs: ( + call_order.append("dedup") or (deduped_dataset, None) + ) + mock_save.side_effect = lambda *args, **kwargs: call_order.append("save") + + _load_split(cfg, split="train") + + assert "dedup" in call_order, "Deduplication should have been called" + assert "save" in call_order, "Save should have been called" + assert call_order.index("dedup") < call_order.index("save"), ( + "Deduplication must occur before saving the dataset" + )