From 32637fad0035ddc6839d517d5b9b6b468825ab13 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 17 Apr 2025 05:02:35 +0700 Subject: [PATCH] fix: preprocess yielding whole dataset to each worker (#2503) [skip ci] --- src/axolotl/utils/data/sft.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 726ec4858..35eab30c5 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -332,16 +332,23 @@ def load_tokenized_prepared_datasets( if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") if isinstance(dataset, IterableDataset): + num_workers = cfg.dataset_processes - def gen_from_iter_ds(_ds, _=None): - yield from _ds + def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]): + """Generator function to correctly splice the dataset for each worker""" + for i, item in enumerate(_ds): + if i % num_workers[0] == worker_id[0]: + yield item ds_from_iter = Dataset.from_generator( functools.partial(gen_from_iter_ds, dataset), features=dataset.features, - num_proc=cfg.dataset_processes, + num_proc=num_workers, split=split, - gen_kwargs={"_": list(range(cfg.dataset_processes))}, + gen_kwargs={ + "worker_id": list(range(num_workers)), + "num_workers": [num_workers] * num_workers, + }, ) ds_from_iter.save_to_disk(str(prepared_ds_path)) else: