mask expansion

This commit is contained in:
Sunny Liu
2025-01-22 21:29:52 -05:00
parent bb9bea3110
commit 8b3eec7f6e

View File

@@ -225,7 +225,7 @@ def mask_2d_to_4d(
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len
# mask = mask.unsqueeze(1).unsqueeze(2)