skip mask conversion if already 4d
This commit is contained in:
@@ -225,6 +225,9 @@ def mask_2d_to_4d(
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
|
||||
if len(mask.size()) == 4:
|
||||
return mask
|
||||
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
|
||||
Reference in New Issue
Block a user