Fix sample packing producing longer sequences than specified by sequence_len (#2332)

* Extend MultiPackBatchSampler test to include shorter sequence length and drop long sequences filter

* Fix get_dataset_lengths for datasets that were previously filtered (e.g., with drop_long_seq_in_dataset)

* Update src/axolotl/utils/samplers/utils.py

Fix get_dataset_lengths for datasets that do not have position_ids or length attributes

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
Tobias
2025-02-19 06:02:35 +01:00
committed by GitHub
parent 23a9fcb0a7
commit 8dfadc2b3c
2 changed files with 12 additions and 12 deletions

View File

@@ -5,12 +5,12 @@ import numpy as np
def get_dataset_lengths(dataset):
    """Return the per-sample sequence lengths of *dataset* as a NumPy array.

    The lengths are taken from the first source that is available:

    1. a precomputed ``"length"`` column;
    2. the ``"position_ids"`` column, using the last position id of each
       sample plus one;
    3. the ``"input_ids"`` column, using ``len()`` of each sample.

    Args:
        dataset: Object exposing ``column_names`` and column access via
            ``dataset[name]``.

    Returns:
        A one-dimensional ``numpy.ndarray`` with one entry per sample.
    """
    # NOTE(review): columns are read through the dataset's own interface
    # (``dataset[...]``) rather than the backing ``dataset.data`` table.
    # Presumably the backing table still holds rows that an earlier filter
    # step removed, which would yield stale lengths -- confirm against the
    # dataset library's documentation.
    column_names = dataset.column_names

    if "length" in column_names:
        return np.array(dataset["length"])

    if "position_ids" in column_names:
        # Assumes each sample's position ids count up from 0, so the final
        # id + 1 equals the sample length -- TODO confirm for packed inputs.
        return np.array([ids[-1] + 1 for ids in dataset["position_ids"]])

    # Fallback: measure the token sequences directly.
    return np.array([len(seq) for seq in dataset["input_ids"]])