revert seq len in multipack sampler

This commit is contained in:
Sunny
2025-01-14 11:45:35 -05:00
parent c06a6be915
commit dbcd11e533

View File

@@ -160,24 +160,19 @@ class MultipackBatchSampler(BatchSampler):
for i in range(0, len(batches), self.batch_size)
]
seq_lens = [
[[lengths[idx] for idx in sub_batch] for sub_batch in batch]
for batch in batches
]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, seq_lens
return batches
def __iter__(self):
    """Return an iterator over the packed batches.

    Batches are produced by ``self.generate_batches(set_stats=True)``,
    which returns the list of batches only (no per-batch sequence
    lengths).
    """
    # Call generate_batches exactly once: with set_stats=True it adds to
    # eff_total_used / eff_total_slots, so a second call would count the
    # same epoch twice in the efficiency statistics.
    batches = self.generate_batches(set_stats=True)
    return iter(batches)
def num_batches(self):
    """Return the number of batches the sampler currently produces.

    The count is obtained by regenerating the batches through
    ``self.generate_batches(set_stats=True)`` and taking the length of
    the returned list.

    Note: because ``set_stats=True`` is passed, each call also adds to
    the efficiency counters maintained by ``generate_batches``.
    """
    # generate_batches returns the batch list only; call it once so the
    # statistics are accumulated a single time per invocation.
    batches = self.generate_batches(set_stats=True)
    return len(batches)
def efficiency(self):