diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 74ce1265b..db14a6819 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -160,24 +160,19 @@ class MultipackBatchSampler(BatchSampler): for i in range(0, len(batches), self.batch_size) ] - seq_lens = [ - [[lengths[idx] for idx in sub_batch] for sub_batch in batch] - for batch in batches - ] - # statistics if set_stats: self.eff_total_used += total_used self.eff_total_slots += total_slots - return batches, seq_lens + return batches def __iter__(self): - batches, _ = self.generate_batches(set_stats=True) + batches = self.generate_batches(set_stats=True) return iter(batches) def num_batches(self): - batches, _ = self.generate_batches(set_stats=True) + batches = self.generate_batches(set_stats=True) return len(batches) def efficiency(self):