From 27fec49083908a6fe941b70dee075c663d806050 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 9 May 2025 20:28:58 -0400
Subject: [PATCH] don't sort multipack sampler (#2657)

* don't sort multipack sampler

* increased packing efficiency increases loss

---------

Co-authored-by: Wing Lian
---
 src/axolotl/utils/samplers/multipack.py | 8 ++------
 tests/e2e/integrations/test_kd.py       | 4 ++--
 tests/test_packed_batch_sampler.py      | 1 +
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index a0c30b0d4..c38313c7c 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -78,15 +78,11 @@ def pack_group(
     Returns:
         List of bins, where each bin contains indices of sequences assigned to it
     """
-    # Get sorting indices and sort lengths in descending order
-    indices = np.argsort(sequence_lengths)[::-1]
-    sorted_lengths = sequence_lengths[indices]
-
     bins_remaining_space: list = []  # Tracks remaining capacity in each bin
     bins_assigned_sequences: list = []  # Tracks sequence indices assigned to each bin
 
-    for seq_id, size in enumerate(sorted_lengths):
-        global_idx = indices[seq_id] + group_offset
+    for seq_id, size in enumerate(sequence_lengths):
+        global_idx = seq_id + group_offset
 
         # Try to place sequence in existing bins
         add_new_bin = True
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 9bfe5aaef..f36eef953 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "model.safetensors").exists()
         check_tensorboard(
-            temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
+            temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
         )
 
     @pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()
         check_tensorboard(
-            temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
+            temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
         )
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index dd0386e58..2b03c62f8 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
 
         original_idxs = set(range(len(train_dataset)))
         assert original_idxs == set(batch_idxs)
+        assert len(batch_idxs) == len(set(batch_idxs))