make multipack sampler patch explicit (#3096)
* make multipack sampler patch explicit * combining
This commit is contained in:
@@ -48,7 +48,13 @@ class TestBatchedSamplerPacking:
|
||||
max_seq_length,
|
||||
sequential,
|
||||
):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
|
||||
apply_multipack_dataloader_patch,
|
||||
remove_multipack_dataloader_patch,
|
||||
)
|
||||
|
||||
# Apply the patch for multipack handling
|
||||
apply_multipack_dataloader_patch()
|
||||
|
||||
dataset = dataset_winglian_tiny_shakespeare["train"]
|
||||
|
||||
@@ -101,10 +107,14 @@ class TestBatchedSamplerPacking:
|
||||
for pack in batch:
|
||||
batch_idxs.extend(pack)
|
||||
|
||||
for batch in loader:
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
try:
|
||||
for batch in loader:
|
||||
assert batch["input_ids"].numel() <= batch_size * max_seq_length
|
||||
assert batch["input_ids"].shape[1] == max_seq_length
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
finally:
|
||||
# Clean up: remove the patch after the test
|
||||
remove_multipack_dataloader_patch()
|
||||
|
||||
Reference in New Issue
Block a user