Fix: tokenize stall due to not shuffling dataset (#2845)
* fix: shuffle dataset even if only one to fix tokenize stall * fix: warn if shuffling merged with curriculum sampling * chore: refactor
This commit is contained in:
@@ -524,13 +524,24 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
|||||||
Merged dataset.
|
Merged dataset.
|
||||||
"""
|
"""
|
||||||
if len(datasets) == 1:
|
if len(datasets) == 1:
|
||||||
return datasets[0]
|
ds = datasets[0]
|
||||||
|
|
||||||
|
# Do not shuffle if curriculum sampling is enabled
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
return ds
|
||||||
|
|
||||||
|
return ds.shuffle(seed=cfg.seed)
|
||||||
|
|
||||||
LOG.info("Merging datasets...")
|
LOG.info("Merging datasets...")
|
||||||
merged_dataset = concatenate_datasets(datasets)
|
merged_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
LOG.debug("Shuffling merged datasets...")
|
LOG.debug("Shuffling merged datasets...")
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
LOG.warning(
|
||||||
|
"Shuffling merged datasets with curriculum sampling is not recommended. "
|
||||||
|
"This will randomize the order of samples."
|
||||||
|
)
|
||||||
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
||||||
else:
|
else:
|
||||||
LOG.debug("Not shuffling merged datasets.")
|
LOG.debug("Not shuffling merged datasets.")
|
||||||
|
|||||||
Reference in New Issue
Block a user