revert seq len in multipack sampler
This commit is contained in:
@@ -160,24 +160,19 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
for i in range(0, len(batches), self.batch_size)
|
for i in range(0, len(batches), self.batch_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
seq_lens = [
|
|
||||||
[[lengths[idx] for idx in sub_batch] for sub_batch in batch]
|
|
||||||
for batch in batches
|
|
||||||
]
|
|
||||||
|
|
||||||
# statistics
|
# statistics
|
||||||
if set_stats:
|
if set_stats:
|
||||||
self.eff_total_used += total_used
|
self.eff_total_used += total_used
|
||||||
self.eff_total_slots += total_slots
|
self.eff_total_slots += total_slots
|
||||||
|
|
||||||
return batches, seq_lens
|
return batches
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
batches, _ = self.generate_batches(set_stats=True)
|
batches = self.generate_batches(set_stats=True)
|
||||||
return iter(batches)
|
return iter(batches)
|
||||||
|
|
||||||
def num_batches(self):
|
def num_batches(self):
|
||||||
batches, _ = self.generate_batches(set_stats=True)
|
batches = self.generate_batches(set_stats=True)
|
||||||
return len(batches)
|
return len(batches)
|
||||||
|
|
||||||
def efficiency(self):
|
def efficiency(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user