wip flex block mask creation

This commit is contained in:
bursteratom
2025-01-29 00:25:25 -05:00
parent b31796a681
commit ba88bc7840

View File

@@ -169,19 +169,26 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature == "length":
continue
if feature == "attention_mask":
seq_len_list = [
get_seqlens_from_pos_ids(item["position_ids"])
for item in features_
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(seq_len_list)
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
out = super().__call__(out_features, return_tensors=return_tensors)
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
doc_mask =
out["attention_mask"]
return out
@dataclass