fix: preprocess yielding whole dataset to each worker (#2503) [skip ci]
This commit is contained in:
@@ -332,16 +332,23 @@ def load_tokenized_prepared_datasets(
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
if isinstance(dataset, IterableDataset):
|
||||
num_workers = cfg.dataset_processes
|
||||
|
||||
def gen_from_iter_ds(_ds, _=None):
|
||||
yield from _ds
|
||||
def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
|
||||
"""Generator function to correctly splice the dataset for each worker"""
|
||||
for i, item in enumerate(_ds):
|
||||
if i % num_workers[0] == worker_id[0]:
|
||||
yield item
|
||||
|
||||
ds_from_iter = Dataset.from_generator(
|
||||
functools.partial(gen_from_iter_ds, dataset),
|
||||
features=dataset.features,
|
||||
num_proc=cfg.dataset_processes,
|
||||
num_proc=num_workers,
|
||||
split=split,
|
||||
gen_kwargs={"_": list(range(cfg.dataset_processes))},
|
||||
gen_kwargs={
|
||||
"worker_id": list(range(num_workers)),
|
||||
"num_workers": [num_workers] * num_workers,
|
||||
},
|
||||
)
|
||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user