Compare commits

2 commits: streaming...squash_pos

| SHA1 |
|---|
| 21ba1cd3f1 |
| eea7a006e1 |
```diff
@@ -6,7 +6,6 @@ from dataclasses import dataclass
 
 from datasets import Dataset
 
-import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
```
```diff
@@ -476,6 +476,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 )
             ):
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
+                if self.cfg.squash_position_ids:
+                    kwargs["squash_position_ids"] = True
             else:
                 collator = BatchSamplerDataCollatorForSeq2Seq
         else:
```
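For context, a hedged illustration of what the new `squash_position_ids` kwarg implies (a sketch of the concept only, not the actual `V2BatchSamplerDataCollatorForSeq2Seq` logic, which this compare does not show): with sample packing, `position_ids` conventionally restart at 0 for each packed sub-sequence, while "squashing" numbers them continuously across the whole pack, which is how it effectively extends context length.

```python
import torch

# Illustration only: two sequences of lengths 3 and 2 packed into one row.
lengths = [3, 2]

# Conventional packing: position_ids restart for each packed sub-sequence.
per_sequence = torch.cat([torch.arange(n) for n in lengths])
# tensor([0, 1, 2, 0, 1])

# "Squashed" position_ids: continuous across the whole pack, so the model
# is trained on positions beyond any single sequence's length.
squashed = torch.arange(sum(lengths))
# tensor([0, 1, 2, 3, 4])
```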
```diff
@@ -277,6 +277,14 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
+        if self.cfg.sample_packing:
+            from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+                apply_multipack_dataloader_patch,
+            )
+
+            LOG.info("Applying multipack dataloader patch for sample packing...")
+            apply_multipack_dataloader_patch()
+
     def _apply_fsdp2_bnb_patches(self):
         """Apply FSDP2 BNB patches."""
         if (
```
```diff
@@ -1,4 +1,4 @@
-"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
+"""Monkey patches for the dataset fetcher to handle batches of packed indexes."""
 
 # pylint: disable=protected-access
 
@@ -6,10 +6,20 @@ import torch
 from torch.utils.data._utils.fetch import _BaseDatasetFetcher
 from torch.utils.data._utils.worker import _worker_loop
 
+_ORIGINAL_MAP_DATASET_FETCHER = None
+_ORIGINAL_WORKER_LOOP = None
+_IS_PATCHED = False
+
 
 class _MapDatasetFetcher(_BaseDatasetFetcher):
+    """
+    Custom dataset fetcher that handles nested batch structures from
+    MultipackBatchSampler.
+    """
+
    def fetch(self, possibly_batched_index):
        if isinstance(possibly_batched_index[0], list):
+            # Handle nested structure from MultipackBatchSampler
            data = [None for i in possibly_batched_index]
            for i, possibly_batched_index_ in enumerate(possibly_batched_index):
                if self.auto_collation:
@@ -23,6 +33,7 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
                else:
                    data[i] = self.dataset[possibly_batched_index_]
        else:
+            # Standard batch handling
            if self.auto_collation:
                if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                    data = self.dataset.__getitems__(possibly_batched_index)
```
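As a sketch of the index structure this `fetch` override distinguishes (the concrete indices are made up for illustration): `MultipackBatchSampler` yields each batch as a list of bins, where every bin is itself a list of dataset indices, while a stock sampler yields a flat list.

```python
# Flat batch from a standard sampler: one dataset index per sample.
standard_batch = [0, 1, 2, 3]

# Nested batch from MultipackBatchSampler: each inner list is one bin of
# packed sample indices. The patched fetcher detects this shape via
# isinstance(possibly_batched_index[0], list) and fetches bin by bin.
multipack_batch = [[0, 3], [1, 2, 4]]
```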
```diff
@@ -34,14 +45,54 @@ class _MapDatasetFetcher(_BaseDatasetFetcher):
 
 
 def patch_fetchers():
+    """Apply patches to PyTorch's DataLoader components."""
     torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
     torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
 
 
 def patched_worker_loop(*args, **kwargs):
+    """Worker loop that ensures patches are applied in worker processes."""
     patch_fetchers()
     return _worker_loop(*args, **kwargs)
 
 
-torch.utils.data._utils.worker._worker_loop = patched_worker_loop
-patch_fetchers()
+def apply_multipack_dataloader_patch():
+    """
+    This patch allows DataLoader to correctly process batches that contain multiple bins
+    of packed sequences.
+    """
+    # pylint: disable=global-statement
+    global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED
+
+    if _IS_PATCHED:
+        return
+
+    # Store original implementations
+    _ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher
+    _ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop
+
+    # Apply patches
+    patch_fetchers()
+    torch.utils.data._utils.worker._worker_loop = patched_worker_loop
+
+    _IS_PATCHED = True
+
+
+def remove_multipack_dataloader_patch():
+    """Remove the monkeypatch and restore original PyTorch DataLoader behavior."""
+    # pylint: disable=global-statement
+    global _IS_PATCHED
+
+    if not _IS_PATCHED:
+        return
+
+    if _ORIGINAL_MAP_DATASET_FETCHER:
+        torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER
+        torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = (
+            _ORIGINAL_MAP_DATASET_FETCHER
+        )
+
+    if _ORIGINAL_WORKER_LOOP:
+        torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
+
+    _IS_PATCHED = False
```
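The intended call pattern, mirroring the test change further down (the `try`/`finally` is the test's own cleanup idiom; the body of the `try` is elided here):

```python
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
    apply_multipack_dataloader_patch,
    remove_multipack_dataloader_patch,
)

apply_multipack_dataloader_patch()  # idempotent: no-op if already patched
try:
    ...  # build the DataLoader with MultipackBatchSampler and iterate
finally:
    remove_multipack_dataloader_patch()  # restore stock PyTorch behavior
```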
```diff
@@ -459,6 +459,12 @@ class AxolotlInputConfig(
             "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
         },
     )
+    squash_position_ids: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to squash position_ids for packing, effectively extending context length."
+        },
+    )
     eval_sample_packing: bool | None = Field(
         default=None,
         json_schema_extra={
```
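A hedged sketch of how the two related options would be set together (field names come from this compare; `DictDefault` is the mapping type axolotl uses for configs, and enabling both at once is an assumption here, not something this compare shows):

```python
from axolotl.utils.dict import DictDefault

# sample_packing gates the multipack dataloader patch in PatchManager;
# squash_position_ids adds the collator kwarg in HFCausalTrainerBuilder.
cfg = DictDefault(
    {
        "sample_packing": True,
        "squash_position_ids": True,
    }
)
```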
```diff
@@ -48,7 +48,13 @@ class TestBatchedSamplerPacking:
         max_seq_length,
         sequential,
     ):
-        import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
+        from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+            apply_multipack_dataloader_patch,
+            remove_multipack_dataloader_patch,
+        )
+
+        # Apply the patch for multipack handling
+        apply_multipack_dataloader_patch()
 
         dataset = dataset_winglian_tiny_shakespeare["train"]
@@ -101,10 +107,14 @@ class TestBatchedSamplerPacking:
             for pack in batch:
                 batch_idxs.extend(pack)
 
-        for batch in loader:
-            assert batch["input_ids"].numel() <= batch_size * max_seq_length
-            assert batch["input_ids"].shape[1] == max_seq_length
+        try:
+            for batch in loader:
+                assert batch["input_ids"].numel() <= batch_size * max_seq_length
+                assert batch["input_ids"].shape[1] == max_seq_length
 
-        original_idxs = set(range(len(train_dataset)))
-        assert original_idxs == set(batch_idxs)
-        assert len(batch_idxs) == len(set(batch_idxs))
+            original_idxs = set(range(len(train_dataset)))
+            assert original_idxs == set(batch_idxs)
+            assert len(batch_idxs) == len(set(batch_idxs))
+        finally:
+            # Clean up: remove the patch after the test
+            remove_multipack_dataloader_patch()
```