Files
axolotl/tests/test_save_deduplicated.py
Manas Vardhan 474208b794 fix: Save de-duplicated dataset during pre-processing (#3427)
* fix: run deduplication before saving dataset during preprocessing

Move deduplicate_and_log_datasets call before save_preprocessed_dataset
in both SFT and RL data loading pipelines. This ensures the saved
preprocessed dataset is already de-duplicated, so subsequent loads
from cache don't contain duplicates.

Fixes #2719

* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup

- Add dataset_exact_deduplication to the hash string in
  generate_dataset_hash_from_config so cached datasets are invalidated
  when the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and
  dataset_exact_deduplication=True, since dedup will be silently
  skipped in that configuration (both SFT and RL paths).

* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting

- Add config validator (check_deduplication_with_skip_prepare) that raises
  ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
  axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues

* refactor: inline deduplicate function per review feedback

* fix: correct test fixture and lint errors

---------

Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-03-02 12:55:59 -05:00

211 lines
7.4 KiB
Python

"""Tests to verify that deduplication runs before dataset saving during preprocessing.
This addresses GitHub issue #2719: Save De-duplicated Set During Pre-processing.
"""
from unittest.mock import MagicMock, patch
from datasets import Dataset
from axolotl.utils.dict import DictDefault
class TestSFTSaveDeduplicatedBeforeSave:
    """Verify that in SFT data loading, deduplication occurs before saving.

    The loading, merging, hashing, de-duplication and saving helpers used by
    ``_load_raw_datasets`` are patched at their ``axolotl.utils.data.sft``
    import site, so these tests observe only how the loader sequences and
    wires those steps together.
    """

    @staticmethod
    def _make_cfg(dataset_exact_deduplication):
        """Return a minimal SFT config with a single alpaca-style dataset entry.

        Args:
            dataset_exact_deduplication: value of the flag under test.
        """
        return DictDefault(
            {
                "skip_prepare_dataset": False,
                "dataset_exact_deduplication": dataset_exact_deduplication,
                "sequence_len": 1024,
                "eval_sequence_len": None,
                "sample_packing": False,
                "is_preprocess": False,
                "seed": 42,
                "datasets": [{"path": "test", "type": "alpaca"}],
            }
        )

    @staticmethod
    def _make_tokenizer():
        """Return a MagicMock tokenizer with ``name_or_path`` set."""
        tokenizer = MagicMock()
        tokenizer.name_or_path = "test-tokenizer"
        return tokenizer

    @staticmethod
    def _saved_call_args(mock_save):
        """Return every argument of the single ``save_preprocessed_dataset`` call.

        Positional and keyword arguments are flattened into one list so the
        assertions do not depend on how the loader passes the dataset.
        """
        mock_save.assert_called_once()
        args, kwargs = mock_save.call_args
        return [*args, *kwargs.values()]

    @patch("axolotl.utils.data.sft.save_preprocessed_dataset")
    @patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
    @patch("axolotl.utils.data.sft.deduplicate_and_log_datasets")
    @patch("axolotl.utils.data.sft.merge_datasets")
    @patch("axolotl.utils.data.sft._load_and_process_single_dataset")
    @patch("axolotl.utils.data.sft.datasets_with_name_generator")
    def test_dedup_called_before_save_sft(
        self,
        mock_datasets_gen,
        mock_load_single,
        mock_merge,
        mock_dedup,
        mock_gen_hash,
        mock_save,
    ):
        """Deduplication should be called before save_preprocessed_dataset in SFT,
        and the de-duplicated dataset must be the one that is saved."""
        from axolotl.utils.data.sft import _load_raw_datasets

        # Raw data contains a duplicate row; the dedup mock returns a new object
        # so we can tell by identity which dataset reached the save step.
        dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
        deduped_dataset = Dataset.from_dict({"text": ["a", "b"], "label": [1, 2]})

        mock_datasets_gen.return_value = [
            DictDefault({"path": "test", "type": "alpaca"})
        ]
        mock_load_single.return_value = (dataset, None)
        mock_merge.return_value = dataset
        mock_gen_hash.return_value = "testhash"

        # Record invocation order. The recorders accept *args as well as
        # **kwargs so the test does not depend on the loader's call convention.
        call_order = []

        def _record_dedup(*args, **kwargs):
            call_order.append("dedup")
            return deduped_dataset, None

        def _record_save(*args, **kwargs):
            call_order.append("save")

        mock_dedup.side_effect = _record_dedup
        mock_save.side_effect = _record_save

        cfg = self._make_cfg(dataset_exact_deduplication=True)
        _load_raw_datasets(
            cfg=cfg,
            datasets_configs=cfg.datasets,
            tokenizer=self._make_tokenizer(),
            split="train",
        )

        assert "dedup" in call_order, "Deduplication should have been called"
        assert "save" in call_order, "Save should have been called"
        assert call_order.index("dedup") < call_order.index("save"), (
            "Deduplication must occur before saving the dataset"
        )
        # Ordering alone is not enough: the dedup result must be what is saved.
        saved_args = self._saved_call_args(mock_save)
        assert any(arg is deduped_dataset for arg in saved_args), (
            "The de-duplicated dataset must be passed to save_preprocessed_dataset"
        )

    @patch("axolotl.utils.data.sft.save_preprocessed_dataset")
    @patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
    @patch("axolotl.utils.data.sft.deduplicate_and_log_datasets")
    @patch("axolotl.utils.data.sft.merge_datasets")
    @patch("axolotl.utils.data.sft._load_and_process_single_dataset")
    @patch("axolotl.utils.data.sft.datasets_with_name_generator")
    def test_no_dedup_when_disabled_sft(
        self,
        mock_datasets_gen,
        mock_load_single,
        mock_merge,
        mock_dedup,
        mock_gen_hash,
        mock_save,
    ):
        """Deduplication should not be called when dataset_exact_deduplication is False."""
        from axolotl.utils.data.sft import _load_raw_datasets

        dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
        mock_datasets_gen.return_value = [
            DictDefault({"path": "test", "type": "alpaca"})
        ]
        mock_load_single.return_value = (dataset, None)
        mock_merge.return_value = dataset
        mock_gen_hash.return_value = "testhash"

        cfg = self._make_cfg(dataset_exact_deduplication=False)
        _load_raw_datasets(
            cfg=cfg,
            datasets_configs=cfg.datasets,
            tokenizer=self._make_tokenizer(),
            split="train",
        )

        mock_dedup.assert_not_called()
        # Turning dedup off must not also turn off persisting the prepared data.
        mock_save.assert_called_once()
class TestRLSaveDeduplicatedBeforeSave:
    """Verify that in RL data loading, deduplication occurs before saving.

    ``_load_split`` is driven with its dataset-loading, merging, hashing,
    de-duplication, saving and tokenizer helpers patched at their
    ``axolotl.utils.data.rl`` import site.
    """

    @staticmethod
    def _make_cfg(dataset_exact_deduplication):
        """Return a minimal DPO config with one dataset entry whose type is None.

        Args:
            dataset_exact_deduplication: value of the flag under test.
        """
        return DictDefault(
            {
                "skip_prepare_dataset": False,
                "dataset_exact_deduplication": dataset_exact_deduplication,
                "sequence_len": 1024,
                "rl": "dpo",
                "datasets": [{"path": "test", "type": None}],
                "hf_use_auth_token": False,
                "dataset_num_proc": 1,
                "is_preprocess": False,
            }
        )

    @staticmethod
    def _make_dataset():
        """Return a prompt/chosen/rejected dataset whose third row repeats the first."""
        return Dataset.from_dict(
            {
                "prompt": ["hi", "bye", "hi"],
                "chosen": ["a", "b", "a"],
                "rejected": ["c", "d", "c"],
            }
        )

    @staticmethod
    def _make_tokenizer():
        """Return a MagicMock tokenizer with ``name_or_path`` set."""
        tokenizer = MagicMock()
        tokenizer.name_or_path = "test-tokenizer"
        return tokenizer

    @staticmethod
    def _saved_call_args(mock_save):
        """Return every argument of the single ``save_preprocessed_dataset`` call.

        Positional and keyword arguments are flattened into one list so the
        assertions do not depend on how the loader passes the dataset.
        """
        mock_save.assert_called_once()
        args, kwargs = mock_save.call_args
        return [*args, *kwargs.values()]

    # Dataset.filter is replaced by an identity function so any filtering the
    # RL loader applies to these untokenized rows leaves them unchanged.
    # load_tokenizer is patched where rl.py imported it, not at its origin.
    @patch.object(Dataset, "filter", lambda self, *args, **kwargs: self)
    @patch("axolotl.utils.data.rl.save_preprocessed_dataset")
    @patch("axolotl.utils.data.rl.generate_dataset_hash_from_config")
    @patch("axolotl.utils.data.rl.deduplicate_and_log_datasets")
    @patch("axolotl.utils.data.rl.merge_datasets")
    @patch("axolotl.utils.data.rl.load_dataset_with_config")
    @patch("axolotl.utils.data.rl.datasets_with_name_generator")
    @patch("axolotl.utils.data.rl.load_tokenizer")
    def test_dedup_called_before_save_rl(
        self,
        mock_load_tokenizer,
        mock_datasets_gen,
        mock_load_dataset,
        mock_merge,
        mock_dedup,
        mock_gen_hash,
        mock_save,
    ):
        """Deduplication should be called before save_preprocessed_dataset in RL,
        and the de-duplicated dataset must be the one that is saved."""
        from axolotl.utils.data.rl import _load_split

        dataset = self._make_dataset()
        # A distinct object, so identity tells us which dataset reached save.
        deduped_dataset = Dataset.from_dict(
            {
                "prompt": ["hi", "bye"],
                "chosen": ["a", "b"],
                "rejected": ["c", "d"],
            }
        )

        mock_load_tokenizer.return_value = self._make_tokenizer()
        mock_datasets_gen.return_value = [DictDefault({"path": "test", "type": None})]
        mock_load_dataset.return_value = dataset
        mock_merge.return_value = dataset
        mock_gen_hash.return_value = "testhash"

        # Record invocation order; accept positional and keyword arguments.
        call_order = []

        def _record_dedup(*args, **kwargs):
            call_order.append("dedup")
            return deduped_dataset, None

        def _record_save(*args, **kwargs):
            call_order.append("save")

        mock_dedup.side_effect = _record_dedup
        mock_save.side_effect = _record_save

        _load_split(self._make_cfg(dataset_exact_deduplication=True), split="train")

        assert "dedup" in call_order, "Deduplication should have been called"
        assert "save" in call_order, "Save should have been called"
        assert call_order.index("dedup") < call_order.index("save"), (
            "Deduplication must occur before saving the dataset"
        )
        # Ordering alone is not enough: the dedup result must be what is saved.
        saved_args = self._saved_call_args(mock_save)
        assert any(arg is deduped_dataset for arg in saved_args), (
            "The de-duplicated dataset must be passed to save_preprocessed_dataset"
        )

    @patch.object(Dataset, "filter", lambda self, *args, **kwargs: self)
    @patch("axolotl.utils.data.rl.save_preprocessed_dataset")
    @patch("axolotl.utils.data.rl.generate_dataset_hash_from_config")
    @patch("axolotl.utils.data.rl.deduplicate_and_log_datasets")
    @patch("axolotl.utils.data.rl.merge_datasets")
    @patch("axolotl.utils.data.rl.load_dataset_with_config")
    @patch("axolotl.utils.data.rl.datasets_with_name_generator")
    @patch("axolotl.utils.data.rl.load_tokenizer")
    def test_no_dedup_when_disabled_rl(
        self,
        mock_load_tokenizer,
        mock_datasets_gen,
        mock_load_dataset,
        mock_merge,
        mock_dedup,
        mock_gen_hash,
        mock_save,
    ):
        """Deduplication should not be called when dataset_exact_deduplication is False."""
        from axolotl.utils.data.rl import _load_split

        dataset = self._make_dataset()
        mock_load_tokenizer.return_value = self._make_tokenizer()
        mock_datasets_gen.return_value = [DictDefault({"path": "test", "type": None})]
        mock_load_dataset.return_value = dataset
        mock_merge.return_value = dataset
        mock_gen_hash.return_value = "testhash"

        _load_split(self._make_cfg(dataset_exact_deduplication=False), split="train")

        mock_dedup.assert_not_called()
        # Turning dedup off must not also turn off persisting the prepared data.
        mock_save.assert_called_once()