fix: Save de-duplicated dataset during pre-processing (#3427)
* fix: run deduplication before saving dataset during preprocessing Move deduplicate_and_log_datasets call before save_preprocessed_dataset in both SFT and RL data loading pipelines. This ensures the saved preprocessed dataset is already de-duplicated, so subsequent loads from cache don't contain duplicates. Fixes #2719 * fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup - Add dataset_exact_deduplication to the hash string in generate_dataset_hash_from_config so cached datasets are invalidated when the dedup setting changes. - Log a warning when skip_prepare_dataset=True and dataset_exact_deduplication=True, since dedup will be silently skipped in that configuration (both SFT and RL paths). * fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting - Add config validator (check_deduplication_with_skip_prepare) that raises ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True - Replace runtime warnings in sft.py/rl.py with the validator check - Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of axolotl.loaders.load_tokenizer to properly mock the imported reference - Fix ruff lint (remove unused imports) and formatting issues * refactor: inline deduplicate function per review feedback * fix test fixture, lint --------- Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -246,6 +246,10 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
|||||||
dataset = merge_datasets(split_datasets, cfg)
|
dataset = merge_datasets(split_datasets, cfg)
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
|
# Deduplicate before saving so the saved dataset is already de-duplicated
|
||||||
|
if cfg.dataset_exact_deduplication:
|
||||||
|
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||||
|
|
||||||
# Save preprocessed dataset
|
# Save preprocessed dataset
|
||||||
dataset_hash = generate_dataset_hash_from_config(
|
dataset_hash = generate_dataset_hash_from_config(
|
||||||
cfg, datasets_configs, tokenizer.name_or_path
|
cfg, datasets_configs, tokenizer.name_or_path
|
||||||
|
|||||||
@@ -351,6 +351,10 @@ def _load_raw_datasets(
|
|||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
|
# Deduplicate before saving so the saved dataset is already de-duplicated
|
||||||
|
if cfg.dataset_exact_deduplication:
|
||||||
|
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||||
|
|
||||||
# Save the prepared dataset
|
# Save the prepared dataset
|
||||||
dataset_hash = generate_dataset_hash_from_config(
|
dataset_hash = generate_dataset_hash_from_config(
|
||||||
cfg, datasets_configs, tokenizer.name_or_path
|
cfg, datasets_configs, tokenizer.name_or_path
|
||||||
@@ -438,25 +442,8 @@ def _handle_train_dataset_split(
|
|||||||
)
|
)
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
# No validation split - apply deduplication if needed and return as train dataset
|
# No validation split - deduplication already applied during preprocessing
|
||||||
if cfg.dataset_exact_deduplication:
|
return dataset, None
|
||||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
|
||||||
else:
|
|
||||||
train_dataset = dataset
|
|
||||||
|
|
||||||
return train_dataset, None
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_test_dataset_split(
|
|
||||||
dataset: Dataset, cfg: DictDefault
|
|
||||||
) -> tuple[None, Dataset | None]:
|
|
||||||
"""Handle processing for test split."""
|
|
||||||
if cfg.dataset_exact_deduplication:
|
|
||||||
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
|
||||||
else:
|
|
||||||
eval_dataset = dataset
|
|
||||||
|
|
||||||
return None, eval_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
||||||
@@ -515,6 +502,7 @@ def _load_and_prepare_datasets(
|
|||||||
if split == "train":
|
if split == "train":
|
||||||
train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)
|
train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)
|
||||||
else:
|
else:
|
||||||
train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
|
# Deduplication already applied during preprocessing
|
||||||
|
train_dataset, eval_dataset = None, dataset
|
||||||
|
|
||||||
return train_dataset, eval_dataset, prompters
|
return train_dataset, eval_dataset, prompters
|
||||||
|
|||||||
@@ -520,7 +520,8 @@ def generate_dataset_hash_from_config(
|
|||||||
"""
|
"""
|
||||||
config_str = (
|
config_str = (
|
||||||
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
|
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
|
||||||
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
|
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
|
||||||
|
f"{cfg.dataset_exact_deduplication or False}|"
|
||||||
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
|
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
|
||||||
f"|{tokenizer_name}"
|
f"|{tokenizer_name}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1509,3 +1509,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"dataset_exact_deduplication is not available for streaming datasets. "
|
"dataset_exact_deduplication is not available for streaming datasets. "
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_deduplication_with_skip_prepare(cls, data):
|
||||||
|
if data.get("dataset_exact_deduplication") and data.get("skip_prepare_dataset"):
|
||||||
|
raise ValueError(
|
||||||
|
"dataset_exact_deduplication=True has no effect when "
|
||||||
|
"skip_prepare_dataset=True. Deduplication runs as part of the "
|
||||||
|
"prepare pipeline, which is skipped. Either set "
|
||||||
|
"skip_prepare_dataset: false or disable "
|
||||||
|
"dataset_exact_deduplication."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
210
tests/test_save_deduplicated.py
Normal file
210
tests/test_save_deduplicated.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Tests to verify that deduplication runs before dataset saving during preprocessing.
|
||||||
|
|
||||||
|
This addresses GitHub issue #2719: Save De-duplicated Set During Pre-processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestSFTSaveDeduplicatedBeforeSave:
|
||||||
|
"""Verify that in SFT data loading, deduplication occurs before saving."""
|
||||||
|
|
||||||
|
@patch("axolotl.utils.data.sft.save_preprocessed_dataset")
|
||||||
|
@patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
|
||||||
|
@patch("axolotl.utils.data.sft.deduplicate_and_log_datasets")
|
||||||
|
@patch("axolotl.utils.data.sft.merge_datasets")
|
||||||
|
@patch("axolotl.utils.data.sft._load_and_process_single_dataset")
|
||||||
|
@patch("axolotl.utils.data.sft.datasets_with_name_generator")
|
||||||
|
def test_dedup_called_before_save_sft(
|
||||||
|
self,
|
||||||
|
mock_datasets_gen,
|
||||||
|
mock_load_single,
|
||||||
|
mock_merge,
|
||||||
|
mock_dedup,
|
||||||
|
mock_gen_hash,
|
||||||
|
mock_save,
|
||||||
|
):
|
||||||
|
"""Deduplication should be called before save_preprocessed_dataset in SFT."""
|
||||||
|
from axolotl.utils.data.sft import _load_raw_datasets
|
||||||
|
|
||||||
|
# Set up mock data
|
||||||
|
dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
|
||||||
|
deduped_dataset = Dataset.from_dict({"text": ["a", "b"], "label": [1, 2]})
|
||||||
|
|
||||||
|
mock_datasets_gen.return_value = [
|
||||||
|
DictDefault({"path": "test", "type": "alpaca"})
|
||||||
|
]
|
||||||
|
mock_load_single.return_value = (dataset, None)
|
||||||
|
mock_merge.return_value = dataset
|
||||||
|
mock_dedup.return_value = (deduped_dataset, None)
|
||||||
|
mock_gen_hash.return_value = "testhash"
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"skip_prepare_dataset": False,
|
||||||
|
"dataset_exact_deduplication": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"eval_sequence_len": None,
|
||||||
|
"sample_packing": False,
|
||||||
|
"is_preprocess": False,
|
||||||
|
"seed": 42,
|
||||||
|
"datasets": [{"path": "test", "type": "alpaca"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.name_or_path = "test-tokenizer"
|
||||||
|
|
||||||
|
# Track call order
|
||||||
|
call_order = []
|
||||||
|
mock_dedup.side_effect = lambda **kwargs: (
|
||||||
|
call_order.append("dedup") or (deduped_dataset, None)
|
||||||
|
)
|
||||||
|
mock_save.side_effect = lambda *args, **kwargs: call_order.append("save")
|
||||||
|
|
||||||
|
_load_raw_datasets(
|
||||||
|
cfg=cfg,
|
||||||
|
datasets_configs=cfg.datasets,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
split="train",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify dedup was called
|
||||||
|
assert "dedup" in call_order, "Deduplication should have been called"
|
||||||
|
# Verify save was called
|
||||||
|
assert "save" in call_order, "Save should have been called"
|
||||||
|
# Verify dedup happened before save
|
||||||
|
assert call_order.index("dedup") < call_order.index("save"), (
|
||||||
|
"Deduplication must occur before saving the dataset"
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("axolotl.utils.data.sft.save_preprocessed_dataset")
|
||||||
|
@patch("axolotl.utils.data.sft.generate_dataset_hash_from_config")
|
||||||
|
@patch("axolotl.utils.data.sft.merge_datasets")
|
||||||
|
@patch("axolotl.utils.data.sft._load_and_process_single_dataset")
|
||||||
|
@patch("axolotl.utils.data.sft.datasets_with_name_generator")
|
||||||
|
def test_no_dedup_when_disabled_sft(
|
||||||
|
self,
|
||||||
|
mock_datasets_gen,
|
||||||
|
mock_load_single,
|
||||||
|
mock_merge,
|
||||||
|
mock_gen_hash,
|
||||||
|
mock_save,
|
||||||
|
):
|
||||||
|
"""Deduplication should not be called when dataset_exact_deduplication is False."""
|
||||||
|
from axolotl.utils.data.sft import _load_raw_datasets
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict({"text": ["a", "b", "a"], "label": [1, 2, 1]})
|
||||||
|
|
||||||
|
mock_datasets_gen.return_value = [
|
||||||
|
DictDefault({"path": "test", "type": "alpaca"})
|
||||||
|
]
|
||||||
|
mock_load_single.return_value = (dataset, None)
|
||||||
|
mock_merge.return_value = dataset
|
||||||
|
mock_gen_hash.return_value = "testhash"
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"skip_prepare_dataset": False,
|
||||||
|
"dataset_exact_deduplication": False,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"eval_sequence_len": None,
|
||||||
|
"sample_packing": False,
|
||||||
|
"is_preprocess": False,
|
||||||
|
"seed": 42,
|
||||||
|
"datasets": [{"path": "test", "type": "alpaca"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.name_or_path = "test-tokenizer"
|
||||||
|
|
||||||
|
with patch("axolotl.utils.data.sft.deduplicate_and_log_datasets") as mock_dedup:
|
||||||
|
_load_raw_datasets(
|
||||||
|
cfg=cfg,
|
||||||
|
datasets_configs=cfg.datasets,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
split="train",
|
||||||
|
)
|
||||||
|
mock_dedup.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRLSaveDeduplicatedBeforeSave:
|
||||||
|
"""Verify that in RL data loading, deduplication occurs before saving."""
|
||||||
|
|
||||||
|
@patch.object(Dataset, "filter", lambda self, *args, **kwargs: self)
|
||||||
|
@patch("axolotl.utils.data.rl.save_preprocessed_dataset")
|
||||||
|
@patch("axolotl.utils.data.rl.generate_dataset_hash_from_config")
|
||||||
|
@patch("axolotl.utils.data.rl.deduplicate_and_log_datasets")
|
||||||
|
@patch("axolotl.utils.data.rl.merge_datasets")
|
||||||
|
@patch("axolotl.utils.data.rl.load_dataset_with_config")
|
||||||
|
@patch("axolotl.utils.data.rl.datasets_with_name_generator")
|
||||||
|
@patch("axolotl.utils.data.rl.load_tokenizer")
|
||||||
|
def test_dedup_called_before_save_rl(
|
||||||
|
self,
|
||||||
|
mock_load_tokenizer,
|
||||||
|
mock_datasets_gen,
|
||||||
|
mock_load_dataset,
|
||||||
|
mock_merge,
|
||||||
|
mock_dedup,
|
||||||
|
mock_gen_hash,
|
||||||
|
mock_save,
|
||||||
|
):
|
||||||
|
"""Deduplication should be called before save_preprocessed_dataset in RL."""
|
||||||
|
from axolotl.utils.data.rl import _load_split
|
||||||
|
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"prompt": ["hi", "bye", "hi"],
|
||||||
|
"chosen": ["a", "b", "a"],
|
||||||
|
"rejected": ["c", "d", "c"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
deduped_dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"prompt": ["hi", "bye"],
|
||||||
|
"chosen": ["a", "b"],
|
||||||
|
"rejected": ["c", "d"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_datasets_gen.return_value = [DictDefault({"path": "test", "type": None})]
|
||||||
|
mock_load_dataset.return_value = dataset
|
||||||
|
mock_merge.return_value = dataset
|
||||||
|
mock_dedup.return_value = (deduped_dataset, None)
|
||||||
|
mock_gen_hash.return_value = "testhash"
|
||||||
|
|
||||||
|
tokenizer = MagicMock()
|
||||||
|
tokenizer.name_or_path = "test-tokenizer"
|
||||||
|
mock_load_tokenizer.return_value = tokenizer
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"skip_prepare_dataset": False,
|
||||||
|
"dataset_exact_deduplication": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"rl": "dpo",
|
||||||
|
"datasets": [{"path": "test", "type": None}],
|
||||||
|
"hf_use_auth_token": False,
|
||||||
|
"dataset_num_proc": 1,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
call_order = []
|
||||||
|
mock_dedup.side_effect = lambda **kwargs: (
|
||||||
|
call_order.append("dedup") or (deduped_dataset, None)
|
||||||
|
)
|
||||||
|
mock_save.side_effect = lambda *args, **kwargs: call_order.append("save")
|
||||||
|
|
||||||
|
_load_split(cfg, split="train")
|
||||||
|
|
||||||
|
assert "dedup" in call_order, "Deduplication should have been called"
|
||||||
|
assert "save" in call_order, "Save should have been called"
|
||||||
|
assert call_order.index("dedup") < call_order.index("save"), (
|
||||||
|
"Deduplication must occur before saving the dataset"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user