Fix sample packing producing longer sequences than specified by sequence_len (#2332)
* Extend MultiPackBatchSampler test to include shorter sequence length and drop long sequences filter * Fix get_dataset_lengths for datasets that were previously filtered (e.g., with drop_long_seq_in_dataset) * Update src/axolotl/utils/samplers/utils.py Fix get_dataset_lengths for datasets that do not have position_ids or length attributes Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -5,12 +5,12 @@ import numpy as np
|
||||
|
||||
|
||||
def get_dataset_lengths(dataset):
|
||||
if "length" in dataset.data.column_names:
|
||||
lengths = np.array(dataset.data.column("length"))
|
||||
elif "position_ids" in dataset.data.column_names:
|
||||
position_ids = dataset.data.column("position_ids")
|
||||
if "length" in dataset.column_names:
|
||||
lengths = np.array(dataset["length"])
|
||||
elif "position_ids" in dataset.column_names:
|
||||
position_ids = dataset["position_ids"]
|
||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||
else:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
input_ids = dataset["input_ids"]
|
||||
lengths = np.array([len(seq) for seq in input_ids])
|
||||
return lengths
|
||||
|
||||
Reference in New Issue
Block a user