mask expansion
This commit is contained in:
@@ -229,7 +229,7 @@ def mask_2d_to_4d(
|
|||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
# mask = mask.unsqueeze(1).unsqueeze(2)
|
# mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
mask = mask[:, None, :].expand(bsz, 1, tgt_len, src_len)
|
mask = mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len)
|
||||||
|
|
||||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
binary_mask = torch.where(
|
binary_mask = torch.where(
|
||||||
|
|||||||
Reference in New Issue
Block a user