Switch to parallel FFD bin packing algorithm. (#1619)
* Switch to a parallel first-fit-decreasing (FFD) bin packing algorithm. Add support for packing in a distributed context. Add the packing efficiency estimate back.
* Revert changes to distributed code
* chore: lint
* Fix config with new params for the packing test
* Add sample_packing_group_size and sample_packing_bin_size to the cfg schema
* Fix lambda function
* Fix sampler/dataloader calculations for packing

Co-authored-by: dsesclei <dave@sescleifer.com>
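For background, first-fit decreasing sorts the sequence lengths in decreasing order and drops each sequence into the first bin that still has room for it. Below is a minimal single-process sketch under that definition; `ffd_pack`, the sample lengths, and the efficiency formula are illustrative assumptions, not the sampler's actual implementation.

```python
# Minimal FFD sketch: sort lengths descending, place each sequence into the
# first bin with enough remaining token capacity, opening a new bin otherwise.
# ffd_pack and the sample values below are illustrative, not library code.
def ffd_pack(lengths: list[int], bin_capacity: int) -> list[list[int]]:
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins: list[list[int]] = []  # dataset indices assigned to each bin
    remaining: list[int] = []   # leftover token capacity per bin
    for idx in order:
        for b, cap in enumerate(remaining):
            if lengths[idx] <= cap:  # first open bin that still fits it
                bins[b].append(idx)
                remaining[b] -= lengths[idx]
                break
        else:  # no open bin fits; start a new one
            bins.append([idx])
            remaining.append(bin_capacity - lengths[idx])
    return bins


lengths = [900, 700, 600, 400, 300, 100]
bins = ffd_pack(lengths, bin_capacity=1024)
# Packing efficiency estimate: share of bin capacity filled with real tokens.
efficiency = sum(lengths) / (len(bins) * 1024)
print(bins, f"{efficiency:.1%}")  # [[0, 5], [1, 4], [2, 3]] 97.7%
```

Sorting first is what makes FFD effective: long sequences claim fresh bins early and short ones backfill the leftover capacity, which also makes a simple efficiency estimate like the one above easy to report.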
```diff
@@ -62,12 +62,14 @@ class TestBatchedSamplerPacking:
             dataset,
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
+        lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(
             sampler=RandomSampler(train_dataset),
+            lengths=lengths,
             batch_size=batch_size,
             drop_last=True,
             batch_max_len=max_seq_length,
-            lengths=get_dataset_lengths(train_dataset),
+            group_size=100000,
+            bin_size=200,
         )
 
         loader = DataLoader(
@@ -81,19 +83,15 @@ class TestBatchedSamplerPacking:
             ),
             num_workers=num_workers,
         )
-        inputs = next(iter(loader))
-
-        assert inputs["input_ids"].shape == (batch_size, max_seq_length)
-        assert inputs["labels"].shape == (batch_size, max_seq_length)
-        assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
-        assert inputs["input_ids"].tolist()[0][0] == 2
-        assert inputs["labels"].tolist()[0][0] == -100
-        assert inputs["attention_mask"].tolist()[0][0] == 0
-        assert inputs["attention_mask"].tolist()[0][-1] > 1
+        batch_idxs = []
+        for batch in batch_sampler:
+            for pack in batch:
+                batch_idxs.extend(pack)
+
+        for batch in loader:
+            assert len(batch["input_ids"]) <= batch_size * max_seq_length
+            assert batch["input_ids"].shape[1] == max_seq_length
 
-        if batch_size >= 2:
-            assert inputs["input_ids"].tolist()[1][0] == 2
-            assert inputs["labels"].tolist()[1][0] == -100
-            assert inputs["attention_mask"].tolist()[1][0] == 0
-            assert inputs["attention_mask"].tolist()[1][-1] > 1
+        original_idxs = set(range(len(train_dataset)))
+        assert original_idxs == set(batch_idxs)
```
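The `group_size` and `bin_size` arguments passed to `MultipackBatchSampler` above hint at how the packing is parallelized. Below is a speculative sketch, assuming the index stream is split into chunks of `group_size` samples that are each packed independently; `pack_group` and `pack_parallel` are hypothetical names, and the stand-in packer is next-fit decreasing rather than the sampler's actual FFD code.

```python
# Speculative sketch of the parallel angle: chunks of group_size samples share
# no state, so each chunk can be packed as an independent task. Run under an
# `if __name__ == "__main__":` guard when using process-based pools.
from concurrent.futures import ProcessPoolExecutor
from functools import partial


def pack_group(lengths: list[int], capacity: int, bin_size: int) -> list[list[int]]:
    """Pack one group's lengths into bins of `capacity` tokens, with at most
    `bin_size` sequences per bin (mirroring bin_size=200 in the test)."""
    bins: list[list[int]] = []
    current: list[int] = []
    used = 0
    for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
        if current and (used + lengths[idx] > capacity or len(current) >= bin_size):
            bins.append(current)  # close the current bin and start a new one
            current, used = [], 0
        current.append(idx)
        used += lengths[idx]
    if current:
        bins.append(current)
    return bins


def pack_parallel(
    lengths: list[int], capacity: int, group_size: int, bin_size: int
) -> list[list[int]]:
    groups = [lengths[i : i + group_size] for i in range(0, len(lengths), group_size)]
    packer = partial(pack_group, capacity=capacity, bin_size=bin_size)
    with ProcessPoolExecutor() as pool:  # each group is an independent task
        results = list(pool.map(packer, groups))
    # Shift group-local indices back to positions in the full dataset.
    packed: list[list[int]] = []
    for group_no, group_bins in enumerate(results):
        offset = group_no * group_size
        packed.extend([offset + i for i in b] for b in group_bins)
    return packed
```

Because no bin crosses a group boundary under this scheme, each group's packing can run in a separate worker process or, in a distributed context, be computed consistently on every rank.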