revert mask expand

This commit is contained in:
Sunny Liu
2025-01-23 11:20:38 -05:00
parent 85752cdfc9
commit e8b2789086

View File

@@ -228,8 +228,8 @@ def mask_2d_to_4d(
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len
# mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len)
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(