don't sort multipack sampler (#2657)

* don't sort sequences by length (descending) in the multipack sampler's pack_group; pack them in their original order

* increased packing efficiency increases loss, so raise the train/loss threshold in the knowledge distillation tests from 1.0 to 1.2

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Dan Saunders
2025-05-09 20:28:58 -04:00
committed by Wing Lian
parent 8cda9e93c1
commit 27fec49083
3 changed files with 5 additions and 8 deletions

View File

@@ -78,15 +78,11 @@ def pack_group(
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
for seq_id, size in enumerate(sequence_lengths):
global_idx = seq_id + group_offset
# Try to place sequence in existing bins
add_new_bin = True

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
)

View File

@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
original_idxs = set(range(len(train_dataset)))
assert original_idxs == set(batch_idxs)
assert len(batch_idxs) == len(set(batch_idxs))