Compare commits
2 Commits
streaming...squash_pos
| Author | SHA1 | Date |
|---|---|---|
| | 21ba1cd3f1 | |
| | eea7a006e1 | |
```diff
@@ -6,7 +6,6 @@ from dataclasses import dataclass
 from datasets import Dataset

-import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.data import prepare_datasets, prepare_preference_datasets

```
```diff
@@ -476,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
         ):
             collator = V2BatchSamplerDataCollatorForSeq2Seq
+            if self.cfg.squash_position_ids:
+                kwargs["squash_position_ids"] = True
         else:
             collator = BatchSamplerDataCollatorForSeq2Seq
     else:
```
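The two added lines forward the new flag to the packing collator. As a brief sketch of the idea (an assumption based on the config description later in this diff, which says squashing "effectively extends context length"; the collator internals are not part of this change), squashing would number positions continuously across a pack instead of restarting them per packed sequence:

```python
# Illustrative only -- not the axolotl collator implementation.
# Two sequences of lengths 3 and 4 packed into one row:
import torch

lengths = [3, 4]

# Default packing: position_ids restart for each packed sequence.
restarted = torch.cat([torch.arange(n) for n in lengths])
print(restarted.tolist())  # [0, 1, 2, 0, 1, 2, 3]

# "Squashed": one continuous range across the whole pack, so training
# sees larger position values (effectively extending context length).
squashed = torch.arange(sum(lengths))
print(squashed.tolist())   # [0, 1, 2, 3, 4, 5, 6]
```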
```diff
@@ -277,6 +277,14 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )

+        if self.cfg.sample_packing:
+            from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+                apply_multipack_dataloader_patch,
+            )
+
+            LOG.info("Applying multipack dataloader patch for sample packing...")
+            apply_multipack_dataloader_patch()
+
     def _apply_fsdp2_bnb_patches(self):
         """Apply FSDP2 BNB patches."""
         if (
```
```diff
@@ -1,4 +1,4 @@
-"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
+"""Monkey patches for the dataset fetcher to handle batches of packed indexes."""

 # pylint: disable=protected-access
```
```diff
@@ -6,10 +6,20 @@ import torch
 from torch.utils.data._utils.fetch import _BaseDatasetFetcher
 from torch.utils.data._utils.worker import _worker_loop

+_ORIGINAL_MAP_DATASET_FETCHER = None
+_ORIGINAL_WORKER_LOOP = None
+_IS_PATCHED = False
+

 class _MapDatasetFetcher(_BaseDatasetFetcher):
+    """
+    Custom dataset fetcher that handles nested batch structures from
+    MultipackBatchSampler.
+    """
+
     def fetch(self, possibly_batched_index):
         if isinstance(possibly_batched_index[0], list):
+            # Handle nested structure from MultipackBatchSampler
             data = [None for i in possibly_batched_index]
             for i, possibly_batched_index_ in enumerate(possibly_batched_index):
                 if self.auto_collation:
```
```diff
@@ -23,6 +33,7 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
                 else:
                     data[i] = self.dataset[possibly_batched_index_]
             else:
+                # Standard batch handling
                 if self.auto_collation:
                     if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                         data = self.dataset.__getitems__(possibly_batched_index)
```
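For orientation, the nested structure this fetcher handles looks like the following (a sketch; the index values are made up):

```python
# A stock BatchSampler yields a flat list of dataset indices, while a
# MultipackBatchSampler-style sampler yields a list of bins, each a list
# of indices for sequences packed together.
flat_batch = [3, 7, 8]
nested_batch = [[3, 7], [8, 12]]

# The patched fetch() branches on exactly this check:
assert not isinstance(flat_batch[0], list)
assert isinstance(nested_batch[0], list)
```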
```diff
@@ -34,14 +45,54 @@


 def patch_fetchers():
+    """Apply patches to PyTorch's DataLoader components."""
     torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
     torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher


 def patched_worker_loop(*args, **kwargs):
+    """Worker loop that ensures patches are applied in worker processes."""
     patch_fetchers()
     return _worker_loop(*args, **kwargs)


-torch.utils.data._utils.worker._worker_loop = patched_worker_loop
-patch_fetchers()
+def apply_multipack_dataloader_patch():
+    """
+    This patch allows DataLoader to correctly process batches that contain multiple bins
+    of packed sequences.
+    """
+    # pylint: disable=global-statement
+    global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED
+
+    if _IS_PATCHED:
+        return
+
+    # Store original implementations
+    _ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher
+    _ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop
+
+    # Apply patches
+    patch_fetchers()
+    torch.utils.data._utils.worker._worker_loop = patched_worker_loop
+
+    _IS_PATCHED = True
+
+
+def remove_multipack_dataloader_patch():
+    """Remove the monkeypatch and restore original PyTorch DataLoader behavior."""
+    # pylint: disable=global-statement
+    global _IS_PATCHED
+
+    if not _IS_PATCHED:
+        return
+
+    if _ORIGINAL_MAP_DATASET_FETCHER:
+        torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER
+        torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (
+            _ORIGINAL_MAP_DATASET_FETCHER
+        )
+
+    if _ORIGINAL_WORKER_LOOP:
+        torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
+
+    _IS_PATCHED = False
```
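Putting the new functions together, a minimal usage sketch (the toy dataset and the pre-built batch list are assumptions for illustration, not axolotl classes):

```python
import torch
from torch.utils.data import DataLoader

from axolotl.monkeypatch.data.batch_dataset_fetcher import (
    apply_multipack_dataloader_patch,
    remove_multipack_dataloader_patch,
)

# Toy stand-ins: a list-backed dataset and pre-built nested batches,
# where each batch holds two bins of packed-sequence indices.
dataset = [{"input_ids": torch.tensor([i])} for i in range(6)]
nested_batches = [[[0, 1], [2]], [[3], [4, 5]]]

apply_multipack_dataloader_patch()  # idempotent: guarded by _IS_PATCHED
try:
    loader = DataLoader(
        dataset,
        batch_sampler=nested_batches,  # any iterable of index batches works
        collate_fn=lambda bins: bins,  # identity, to keep the structure visible
    )
    for batch in loader:
        print(batch)  # a list of bins, each a list of examples
finally:
    remove_multipack_dataloader_patch()  # restore the stock fetcher/worker loop
```

With `num_workers > 0`, `patched_worker_loop` re-applies `patch_fetchers()` inside each worker process, which is why the worker loop is patched alongside the fetcher.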
```diff
@@ -459,6 +459,12 @@ class AxolotlInputConfig(
             "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
         },
     )
+    squash_position_ids: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to squash position_ids for packing, effectively extending context length."
+        },
+    )
     eval_sample_packing: bool | None = Field(
         default=None,
         json_schema_extra={
```
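In a user-facing axolotl YAML config these are top-level keys; shown here as an equivalent Python mapping to stay in one language (values chosen purely for illustration):

```python
# Illustrative config fragment exercising the new field alongside packing.
cfg = {
    "sample_packing": True,        # triggers the dataloader patch in PatchManager
    "squash_position_ids": True,   # new optional flag added above; defaults to None
}
```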
```diff
@@ -48,7 +48,13 @@ class TestBatchedSamplerPacking:
         max_seq_length,
         sequential,
     ):
-        import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
+        from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+            apply_multipack_dataloader_patch,
+            remove_multipack_dataloader_patch,
+        )
+
+        # Apply the patch for multipack handling
+        apply_multipack_dataloader_patch()

         dataset = dataset_winglian_tiny_shakespeare["train"]
```
```diff
@@ -101,10 +107,14 @@ class TestBatchedSamplerPacking:
             for pack in batch:
                 batch_idxs.extend(pack)

-        for batch in loader:
-            assert batch["input_ids"].numel() <= batch_size * max_seq_length
-            assert batch["input_ids"].shape[1] == max_seq_length
+        try:
+            for batch in loader:
+                assert batch["input_ids"].numel() <= batch_size * max_seq_length
+                assert batch["input_ids"].shape[1] == max_seq_length

-        original_idxs = set(range(len(train_dataset)))
-        assert original_idxs == set(batch_idxs)
-        assert len(batch_idxs) == len(set(batch_idxs))
+            original_idxs = set(range(len(train_dataset)))
+            assert original_idxs == set(batch_idxs)
+            assert len(batch_idxs) == len(set(batch_idxs))
+        finally:
+            # Clean up: remove the patch after the test
+            remove_multipack_dataloader_patch()
```
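An alternative to the try/finally above would be a pytest fixture that guarantees the same cleanup (a sketch, not part of this change):

```python
import pytest

from axolotl.monkeypatch.data.batch_dataset_fetcher import (
    apply_multipack_dataloader_patch,
    remove_multipack_dataloader_patch,
)


@pytest.fixture
def multipack_patch():
    """Apply the multipack dataloader patch for a test, then restore."""
    apply_multipack_dataloader_patch()
    yield
    remove_multipack_dataloader_patch()
```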