fix: preprocess yielding whole dataset to each worker (#2503) [skip ci]

This commit is contained in:
NanoCode012
2025-04-17 05:02:35 +07:00
committed by GitHub
parent f776f889a1
commit 32637fad00

View File

@@ -332,16 +332,23 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
def gen_from_iter_ds(_ds, _=None):
yield from _ds
def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
"""Generator function to correctly splice the dataset for each worker"""
for i, item in enumerate(_ds):
if i % num_workers[0] == worker_id[0]:
yield item
ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=cfg.dataset_processes,
num_proc=num_workers,
split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))},
gen_kwargs={
"worker_id": list(range(num_workers)),
"num_workers": [num_workers] * num_workers,
},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else: