make multipack sampler patch explicit (#3096)

* make multipack sampler patch explicit

* combining
This commit is contained in:
Dan Saunders
2025-08-22 14:29:10 -04:00
committed by GitHub
parent ab4d604a8f
commit eea7a006e1
4 changed files with 79 additions and 11 deletions

View File

@@ -48,7 +48,13 @@ class TestBatchedSamplerPacking:
max_seq_length,
sequential,
):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
apply_multipack_dataloader_patch,
remove_multipack_dataloader_patch,
)
# Apply the patch for multipack handling
apply_multipack_dataloader_patch()
dataset = dataset_winglian_tiny_shakespeare["train"]
@@ -101,10 +107,14 @@ class TestBatchedSamplerPacking:
for pack in batch:
batch_idxs.extend(pack)
for batch in loader:
assert batch["input_ids"].numel() <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length
try:
for batch in loader:
assert batch["input_ids"].numel() <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs))
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs))
finally:
# Clean up: remove the patch after the test
remove_multipack_dataloader_patch()