don't sort multipack sampler (#2657)
* don't sort multipack sampler * increased packing efficiency increases loss --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -78,15 +78,11 @@ def pack_group(
|
||||
Returns:
|
||||
List of bins, where each bin contains indices of sequences assigned to it
|
||||
"""
|
||||
# Get sorting indices and sort lengths in descending order
|
||||
indices = np.argsort(sequence_lengths)[::-1]
|
||||
sorted_lengths = sequence_lengths[indices]
|
||||
|
||||
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
|
||||
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
|
||||
|
||||
for seq_id, size in enumerate(sorted_lengths):
|
||||
global_idx = indices[seq_id] + group_offset
|
||||
for seq_id, size in enumerate(sequence_lengths):
|
||||
global_idx = seq_id + group_offset
|
||||
|
||||
# Try to place sequence in existing bins
|
||||
add_new_bin = True
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
|
||||
Reference in New Issue
Block a user