22 lines
675 B
Python
Executable File
22 lines
675 B
Python
Executable File
"""
Helper utility for calculating dataset lengths.
"""
|
|
|
|
import numpy as np
|
|
|
|
|
|
def get_dataset_lengths(dataset, from_arrow=False):
    """Return the per-row sequence lengths of ``dataset`` as a NumPy array.

    The source of the lengths is chosen in this order:

    1. a precomputed ``"length"`` column, used as-is;
    2. a ``"position_ids"`` column, where each row's length is taken to be
       its last position id plus one (``0`` for an empty row);
    3. the ``"input_ids"`` column, where each row's length is ``len(row)``.

    Args:
        dataset: Object exposing ``column_names`` and column access by name
            (``dataset["..."]``). When ``from_arrow`` is true and the
            ``"input_ids"`` fallback is used, it must also expose
            ``dataset.data.column(name)``.
        from_arrow: If true, and neither ``"length"`` nor ``"position_ids"``
            is present, read ``"input_ids"`` through ``dataset.data.column``
            instead of ``dataset["input_ids"]``.

    Returns:
        numpy.ndarray holding one length per row.
    """
    columns = dataset.column_names

    if "length" in columns:
        return np.array(dataset["length"])

    if "position_ids" in columns:
        # NOTE(review): this uses only the final position id of each row. If
        # position ids restart inside a row (e.g. packed samples), the value
        # reflects the last segment only -- confirm this is what callers want.
        # Empty rows have no last element, so report them as length 0 rather
        # than raising IndexError.
        return np.array(
            [(row[-1] + 1) if len(row) else 0 for row in dataset["position_ids"]]
        )

    if from_arrow:
        input_ids = dataset.data.column("input_ids")
        # ``otypes`` is required for np.vectorize to accept size-0 input;
        # without it an empty dataset raises ValueError. ``int`` matches the
        # dtype NumPy would infer from ``len``'s Python int result.
        row_len = np.vectorize(len, otypes=[int])
        return row_len(np.array(input_ids, dtype=object))

    return np.array([len(seq) for seq in dataset["input_ids"]])
|