fix attention mask collation (#1603)
This commit is contained in:
@@ -229,9 +229,8 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
if self.multipack_attn:
|
if self.multipack_attn:
|
||||||
arrays = [
|
arrays = [
|
||||||
(i + 1) * np.array(item[feature])
|
(i + 1) * np.array(item)
|
||||||
for i, item in enumerate(features[feature])
|
for i, item in enumerate(features[feature])
|
||||||
if feature in item
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
arrays = [(1) * np.array(item) for item in features[feature]]
|
arrays = [(1) * np.array(item) for item in features[feature]]
|
||||||
|
|||||||
Reference in New Issue
Block a user