fix: Save de-duplicated dataset during pre-processing (#3427)

* fix: run deduplication before saving dataset during preprocessing

Move deduplicate_and_log_datasets call before save_preprocessed_dataset
in both SFT and RL data loading pipelines. This ensures the saved
preprocessed dataset is already de-duplicated, so subsequent loads
from cache don't contain duplicates.

Fixes #2719

* fix: include deduplication flag in dataset hash and warn on skip_prepare_dataset+dedup

- Add dataset_exact_deduplication to the hash string in
  generate_dataset_hash_from_config so cached datasets are invalidated
  when the dedup setting changes.
- Log a warning when skip_prepare_dataset=True and
  dataset_exact_deduplication=True, since dedup will be silently
  skipped in that configuration (both SFT and RL paths).

* fix: add ValueError for skip_prepare+dedup, fix test mock target and formatting

- Add config validator (check_deduplication_with_skip_prepare) that raises
  ValueError when skip_prepare_dataset=True and dataset_exact_deduplication=True
- Replace runtime warnings in sft.py/rl.py with the validator check
- Fix RL test: patch axolotl.utils.data.rl.load_tokenizer instead of
  axolotl.loaders.load_tokenizer to properly mock the imported reference
- Fix ruff lint (remove unused imports) and formatting issues

* refactor: inline deduplicate function per review feedback

* fix test fixture, lint

---------

Co-authored-by: ManasVardhan <manasvardhan@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Manas Vardhan
2026-03-02 09:55:59 -08:00
committed by GitHub
parent 444020b332
commit 474208b794
5 changed files with 237 additions and 21 deletions

View File

@@ -246,6 +246,10 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
dataset = merge_datasets(split_datasets, cfg)
if not cfg.skip_prepare_dataset:
# Deduplicate before saving so the saved dataset is already de-duplicated
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
# Save preprocessed dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path

View File

@@ -351,6 +351,10 @@ def _load_raw_datasets(
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Deduplicate before saving so the saved dataset is already de-duplicated
if cfg.dataset_exact_deduplication:
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
@@ -438,25 +442,8 @@ def _handle_train_dataset_split(
)
return train_dataset, eval_dataset
# No validation split - apply deduplication if needed and return as train dataset
if cfg.dataset_exact_deduplication:
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication:
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
eval_dataset = dataset
return None, eval_dataset
# No validation split - deduplication already applied during preprocessing
return dataset, None
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
@@ -515,6 +502,7 @@ def _load_and_prepare_datasets(
if split == "train":
train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg)
else:
train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
# Deduplication already applied during preprocessing
train_dataset, eval_dataset = None, dataset
return train_dataset, eval_dataset, prompters

View File

@@ -520,7 +520,8 @@ def generate_dataset_hash_from_config(
"""
config_str = (
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
f"{cfg.dataset_exact_deduplication or False}|"
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
f"|{tokenizer_name}"
)

View File

@@ -1509,3 +1509,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"dataset_exact_deduplication is not available for streaming datasets. "
)
return data
@model_validator(mode="before")
@classmethod
def check_deduplication_with_skip_prepare(cls, data):
if data.get("dataset_exact_deduplication") and data.get("skip_prepare_dataset"):
raise ValueError(
"dataset_exact_deduplication=True has no effect when "
"skip_prepare_dataset=True. Deduplication runs as part of the "
"prepare pipeline, which is skipped. Either set "
"skip_prepare_dataset: false or disable "
"dataset_exact_deduplication."
)
return data