mask expansion

This commit is contained in:
Sunny Liu
2025-01-22 21:31:42 -05:00
parent 8b3eec7f6e
commit f2f23c8041

View File

@@ -229,7 +229,7 @@ def mask_2d_to_4d(
tgt_len = tgt_len if tgt_len is not None else src_len
# mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len)
mask = mask[:, None, :].expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(