Sequential sample packing (#2404) [skip ci]

* add sequential sample packing

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
DreamGenX
2025-03-31 21:48:20 +02:00
committed by GitHub
parent 7acf93b59f
commit 4d36ecc724
7 changed files with 174 additions and 11 deletions

View File

@@ -38,8 +38,11 @@ class TestBatchedSamplerPacking:
],
)
@pytest.mark.parametrize("max_seq_length", [4096, 512])
@pytest.mark.parametrize("sequential", [True, False])
@enable_hf_offline
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
def test_packing(
self, batch_size, num_workers, tokenizer, max_seq_length, sequential
):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset(
@@ -75,6 +78,7 @@ class TestBatchedSamplerPacking:
batch_max_len=max_seq_length,
group_size=100000,
bin_size=200,
sequential=sequential,
)
loader = DataLoader(