flex_attn sample packing WIP

This commit is contained in:
Sunny
2025-01-14 00:22:05 -05:00
parent d3a0cb5edb
commit c06a6be915
2 changed files with 9 additions and 4 deletions

View File

@@ -160,19 +160,24 @@ class MultipackBatchSampler(BatchSampler):
for i in range(0, len(batches), self.batch_size)
]
seq_lens = [
[[lengths[idx] for idx in sub_batch] for sub_batch in batch]
for batch in batches
]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches
return batches, seq_lens
def __iter__(self):
    """Iterate over the packed batches.

    Returns:
        An iterator over batches of sample-index groups.
    """
    # generate_batches now returns (batches, seq_lens); only the batches are
    # needed here, so the per-batch sequence lengths are discarded.
    # NOTE: the pre-change duplicate call (`batches = self.generate_batches(...)`)
    # left over from the diff is removed — calling it twice would redo the
    # full packing pass and double-count the efficiency statistics.
    batches, _ = self.generate_batches(set_stats=True)
    return iter(batches)
def num_batches(self):
    """Return the number of packed batches this sampler will yield.

    Returns:
        int: count of batches produced by a full packing pass.
    """
    # Same pattern as __iter__: generate_batches returns (batches, seq_lens),
    # and only the batch list is needed for the count. The leftover duplicate
    # call from the diff is removed to avoid running the packing pass twice
    # and inflating eff_total_used / eff_total_slots.
    batches, _ = self.generate_batches(set_stats=True)
    return len(batches)
def efficiency(self):

View File

@@ -96,7 +96,7 @@ def disable_datasets_caching():
def add_position_ids(sample):
    """Attach position ids and an explicit length to a tokenized sample.

    Args:
        sample: mapping with an "input_ids" sequence (list or tensor of token ids).

    Returns:
        The same mapping, mutated in place, with:
            "position_ids": torch.arange over [0, len(input_ids)), and
            "length": the number of tokens in input_ids.
    """
    sample_len = len(sample["input_ids"])
    # The diff's pre-change line (torch.arange(len(sample["input_ids"])))
    # was a dead store immediately overwritten — keep only the version that
    # reuses the precomputed sample_len.
    sample["position_ids"] = torch.arange(sample_len)
    sample["length"] = sample_len
    return sample