diff --git a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py index df8d106fd..94f97cf94 100644 --- a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py +++ b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py @@ -9,6 +9,9 @@ from torch.utils.data._utils.worker import _worker_loop class _MapDatasetFetcher(_BaseDatasetFetcher): def fetch(self, possibly_batched_index): + if not possibly_batched_index: + return self.collate_fn([]) + if isinstance(possibly_batched_index[0], list): data = [None for i in possibly_batched_index] for i, possibly_batched_index_ in enumerate(possibly_batched_index):