don't sort multipack sampler (#2657)

* don't sort multipack sampler

* increased packing efficiency increases loss

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Dan Saunders
2025-05-09 20:28:58 -04:00
committed by Wing Lian
parent 8cda9e93c1
commit 27fec49083
3 changed files with 5 additions and 8 deletions

View File

@@ -78,15 +78,11 @@ def pack_group(
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
- # Get sorting indices and sort lengths in descending order
- indices = np.argsort(sequence_lengths)[::-1]
- sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
- for seq_id, size in enumerate(sorted_lengths):
- global_idx = indices[seq_id] + group_offset
+ for seq_id, size in enumerate(sequence_lengths):
+ global_idx = seq_id + group_offset
# Try to place sequence in existing bins
add_new_bin = True