skip mask conversion if already 4d

This commit is contained in:
Sunny Liu
2025-01-23 14:01:53 -05:00
parent e8b2789086
commit 555aa5772a

View File

@@ -225,6 +225,9 @@ def mask_2d_to_4d(
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len