diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 3c58b4c85..a537c5b65 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -524,13 +524,24 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: Merged dataset. """ if len(datasets) == 1: - return datasets[0] + ds = datasets[0] + + # Do not shuffle if curriculum sampling is enabled + if cfg.curriculum_sampling: + return ds + + return ds.shuffle(seed=cfg.seed) LOG.info("Merging datasets...") merged_dataset = concatenate_datasets(datasets) if cfg.shuffle_merged_datasets: LOG.debug("Shuffling merged datasets...") + if cfg.curriculum_sampling: + LOG.warning( + "Shuffling merged datasets with curriculum sampling is not recommended. " + "This will randomize the order of samples." + ) merged_dataset = merged_dataset.shuffle(seed=cfg.seed) else: LOG.debug("Not shuffling merged datasets.")