skip mask conversion if already 4d

This commit is contained in:
Sunny Liu
2025-01-23 14:01:53 -05:00
parent e8b2789086
commit 555aa5772a

View File

@@ -225,6 +225,9 @@ def mask_2d_to_4d(
when they attend to each other within that sequence. when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking. This expansion transforms the mask to lower triangular form to prevent future peeking.
""" """
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1]) bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len