Revert "Fix sample packing producing longer sequences than specified by `sequ…"

This reverts commit 8dfadc2b3c.
This commit is contained in:
Sung Ching Liu
2025-02-19 21:13:25 -05:00
committed by GitHub
parent 954e192f38
commit 90dfcd8c03
2 changed files with 12 additions and 12 deletions

View File

@@ -5,12 +5,12 @@ import numpy as np
def get_dataset_lengths(dataset):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
elif "position_ids" in dataset.data.column_names:
position_ids = dataset.data.column("position_ids")
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths