mask expansion

This commit is contained in:
Sunny Liu
2025-01-22 21:29:52 -05:00
parent bb9bea3110
commit 8b3eec7f6e

View File

@@ -225,7 +225,7 @@ def mask_2d_to_4d(
when they attend to each other within that sequence. when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking. This expansion transforms the mask to lower triangular form to prevent future peeking.
""" """
bsz, src_len = mask.size() bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
# mask = mask.unsqueeze(1).unsqueeze(2) # mask = mask.unsqueeze(1).unsqueeze(2)