mask expansion
This commit is contained in:
@@ -225,7 +225,7 @@ def mask_2d_to_4d(
|
|||||||
when they attend to each other within that sequence.
|
when they attend to each other within that sequence.
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
"""
|
"""
|
||||||
bsz, src_len = mask.size()
|
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
# mask = mask.unsqueeze(1).unsqueeze(2)
|
# mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
|||||||
Reference in New Issue
Block a user