mask expansion

This commit is contained in:
Sunny Liu
2025-01-22 21:27:25 -05:00
parent 0dd18a3681
commit bb9bea3110

View File

@@ -228,8 +228,8 @@ def mask_2d_to_4d(
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(