fix attention mask collation (#1603)

This commit is contained in:
Wing Lian
2024-05-14 08:17:30 -04:00
committed by GitHub
parent 5d97e65f95
commit 02982733ec

View File

@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if feature == "attention_mask":
if self.multipack_attn:
arrays = [
-                        (i + 1) * np.array(item[feature])
+                        (i + 1) * np.array(item)
for i, item in enumerate(features[feature])
if feature in item
]
else:
arrays = [(1) * np.array(item) for item in features[feature]]