calculate sample length fixes and SFT splitting fixes (#2351)

* fix chat template splitting long samples across multiple rows

* make the preprocessing faster
This commit is contained in:
Wing Lian
2025-02-20 14:29:58 -05:00
committed by GitHub
parent 954e192f38
commit 02f45e94be
3 changed files with 13 additions and 9 deletions

View File

@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
res[key].append(val)
# If there are no examples left, return an empty dictionary
if not res:

View File

@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
)
try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError:
pass

View File

@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
import numpy as np
def get_dataset_lengths(dataset):
def get_dataset_lengths(dataset, from_arrow=False):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
if from_arrow:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths