diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index 761317dfb..0ff52ebe1 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -6,7 +6,6 @@
 from dataclasses import dataclass
 
 from datasets import Dataset
-import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 628d897d0..4959bd6ba 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -277,6 +277,14 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
+        if self.cfg.sample_packing:
+            from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+                apply_multipack_dataloader_patch,
+            )
+
+            LOG.info("Applying multipack dataloader patch for sample packing...")
+            apply_multipack_dataloader_patch()
+
     def _apply_fsdp2_bnb_patches(self):
         """Apply FSDP2 BNB patches."""
         if (
diff --git a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
index df8d106fd..73bf37b61 100644
--- a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
+++ b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
@@ -1,4 +1,4 @@
-"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
+"""Monkey patches for the dataset fetcher to handle batches of packed indexes."""
 
 # pylint: disable=protected-access
 
@@ -6,10 +6,20 @@
 import torch
 from torch.utils.data._utils.fetch import _BaseDatasetFetcher
 from torch.utils.data._utils.worker import _worker_loop
 
+_ORIGINAL_MAP_DATASET_FETCHER = None
+_ORIGINAL_WORKER_LOOP = None
+_IS_PATCHED = False
+
 class _MapDatasetFetcher(_BaseDatasetFetcher):
+    """
+    Custom dataset fetcher that handles nested batch structures from
+    MultipackBatchSampler.
+    """
+
     def fetch(self, possibly_batched_index):
         if isinstance(possibly_batched_index[0], list):
+            # Handle nested structure from MultipackBatchSampler
             data = [None for i in possibly_batched_index]
             for i, possibly_batched_index_ in enumerate(possibly_batched_index):
                 if self.auto_collation:
@@ -23,6 +33,7 @@
             else:
                 data[i] = self.dataset[possibly_batched_index_]
         else:
+            # Standard batch handling
             if self.auto_collation:
                 if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                     data = self.dataset.__getitems__(possibly_batched_index)
@@ -34,14 +45,54 @@
 
 
 def patch_fetchers():
+    """Apply patches to PyTorch's DataLoader components."""
     torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
     torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
 
 
 def patched_worker_loop(*args, **kwargs):
+    """Worker loop that ensures patches are applied in worker processes."""
     patch_fetchers()
     return _worker_loop(*args, **kwargs)
 
 
-torch.utils.data._utils.worker._worker_loop = patched_worker_loop
-patch_fetchers()
+def apply_multipack_dataloader_patch():
+    """
+    Apply the patch so the DataLoader can correctly process batches that
+    contain multiple bins of packed sequences.
+    """
+    # pylint: disable=global-statement
+    global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED
+
+    if _IS_PATCHED:
+        return
+
+    # Store original implementations
+    _ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher
+    _ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop
+
+    # Apply patches
+    patch_fetchers()
+    torch.utils.data._utils.worker._worker_loop = patched_worker_loop
+
+    _IS_PATCHED = True
+
+
+def remove_multipack_dataloader_patch():
+    """Remove the monkeypatch and restore original PyTorch DataLoader behavior."""
+    # pylint: disable=global-statement
+    global _IS_PATCHED
+
+    if not _IS_PATCHED:
+        return
+
+    if _ORIGINAL_MAP_DATASET_FETCHER:
+        torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER
+        torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (
+            _ORIGINAL_MAP_DATASET_FETCHER
+        )
+
+    if _ORIGINAL_WORKER_LOOP:
+        torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
+
+    _IS_PATCHED = False
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index 47894a35b..d839c6ea3 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -48,7 +48,13 @@ class TestBatchedSamplerPacking:
         max_seq_length,
         sequential,
     ):
-        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
+        from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+            apply_multipack_dataloader_patch,
+            remove_multipack_dataloader_patch,
+        )
+
+        # Apply the patch for multipack handling
+        apply_multipack_dataloader_patch()
 
         dataset = dataset_winglian_tiny_shakespeare["train"]
 
@@ -101,10 +107,14 @@
             for pack in batch:
                 batch_idxs.extend(pack)
 
-        for batch in loader:
-            assert batch["input_ids"].numel() <= batch_size * max_seq_length
-            assert batch["input_ids"].shape[1] == max_seq_length
+        try:
+            for batch in loader:
+                assert batch["input_ids"].numel() <= batch_size * max_seq_length
+                assert batch["input_ids"].shape[1] == max_seq_length
 
-        original_idxs = set(range(len(train_dataset)))
-        assert original_idxs == set(batch_idxs)
-        assert len(batch_idxs) == len(set(batch_idxs))
+            original_idxs = set(range(len(train_dataset)))
+            assert original_idxs == set(batch_idxs)
+            assert len(batch_idxs) == len(set(batch_idxs))
+        finally:
+            # Clean up: remove the patch after the test
+            remove_multipack_dataloader_patch()