Switch to parallel FFD bin packing algorithm. (#1619)

* Switch to parallel FFD bin packing algorithm.

Add support for packing in a distributed context.
Add packing efficiency estimate back.

* revert changes to distributed code

* chore: lint

* fix config w new params for packing test

* add sample_packing_group_size and sample_packing_bin_size to cfg schema

* fix lambda function

* fix sampler/dataloader calculations for packing

---------

Co-authored-by: dsesclei <dave@sescleifer.com>
This commit is contained in:
Wing Lian
2024-05-23 17:32:14 -04:00
committed by GitHub
parent bbfed318bc
commit 367b2e879b
8 changed files with 175 additions and 225 deletions

View File

@@ -62,12 +62,14 @@ class TestBatchedSamplerPacking:
dataset,
)
train_dataset = concatenate_datasets([dataset_wrapper])
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
lengths=lengths,
batch_size=batch_size,
drop_last=True,
batch_max_len=max_seq_length,
lengths=get_dataset_lengths(train_dataset),
group_size=100000,
bin_size=200,
)
loader = DataLoader(
@@ -81,19 +83,15 @@ class TestBatchedSamplerPacking:
),
num_workers=num_workers,
)
inputs = next(iter(loader))
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
assert inputs["labels"].shape == (batch_size, max_seq_length)
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
batch_idxs = []
for batch in batch_sampler:
for pack in batch:
batch_idxs.extend(pack)
assert inputs["input_ids"].tolist()[0][0] == 2
assert inputs["labels"].tolist()[0][0] == -100
assert inputs["attention_mask"].tolist()[0][0] == 0
assert inputs["attention_mask"].tolist()[0][-1] > 1
for batch in loader:
assert len(batch["input_ids"]) <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length
if batch_size >= 2:
assert inputs["input_ids"].tolist()[1][0] == 2
assert inputs["labels"].tolist()[1][0] == -100
assert inputs["attention_mask"].tolist()[1][0] == 0
assert inputs["attention_mask"].tolist()[1][-1] > 1
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)