calculate sample length fixes and SFT splitting fixes (#2351)
* fix chat template splitting long samples across multiple rows * make the preprocessing faster
This commit is contained in:
@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
dict(zip(feature_names, row))
|
dict(zip(feature_names, row))
|
||||||
)
|
)
|
||||||
for key, val in tokenized_prompt.items():
|
for key, val in tokenized_prompt.items():
|
||||||
for i in range(0, len(val), self.sequence_len):
|
res[key].append(val)
|
||||||
res[key].append(val[i : i + self.sequence_len])
|
|
||||||
|
|
||||||
# If there are no examples left, return an empty dictionary
|
# If there are no examples left, return an empty dictionary
|
||||||
if not res:
|
if not res:
|
||||||
|
|||||||
@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
min_input_len = np.min(get_dataset_lengths(dataset))
|
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||||
LOG.debug(f"min_input_len: {min_input_len}")
|
min_input_len = np.min(ds_lengths)
|
||||||
max_input_len = np.max(get_dataset_lengths(dataset))
|
LOG.info(f"min_input_len: {min_input_len}")
|
||||||
LOG.debug(f"max_input_len: {max_input_len}")
|
max_input_len = np.max(ds_lengths)
|
||||||
|
LOG.info(f"max_input_len: {max_input_len}")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_lengths(dataset):
|
def get_dataset_lengths(dataset, from_arrow=False):
|
||||||
if "length" in dataset.column_names:
|
if "length" in dataset.column_names:
|
||||||
lengths = np.array(dataset["length"])
|
lengths = np.array(dataset["length"])
|
||||||
elif "position_ids" in dataset.column_names:
|
elif "position_ids" in dataset.column_names:
|
||||||
position_ids = dataset["position_ids"]
|
position_ids = dataset["position_ids"]
|
||||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||||
else:
|
else:
|
||||||
input_ids = dataset["input_ids"]
|
if from_arrow:
|
||||||
lengths = np.array([len(seq) for seq in input_ids])
|
input_ids = dataset.data.column("input_ids")
|
||||||
|
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||||
|
else:
|
||||||
|
input_ids = dataset["input_ids"]
|
||||||
|
lengths = np.array([len(seq) for seq in input_ids])
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
Reference in New Issue
Block a user