optimized expand mask fn

This commit is contained in:
Wing Lian
2023-07-24 17:11:02 -04:00
parent 7d7b5ebd71
commit 32fed7039d

View File

@@ -9,35 +9,29 @@ import torch
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence. This should result in a block diagonal mask
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
# Initialize a tensor to hold the expanded masks
expanded_masks = torch.zeros(bsz, 1, tgt_len, src_len).to(dtype)
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# For each sequence in the batch
for i in range(bsz):
# Get the mask for this sequence
mask_i = mask[i].unsqueeze(0)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask_i = torch.where(
mask_i != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Create a block-diagonal mask
zero_one_mask_i = torch.eq(mask_i, mask_i.t()).int() * binary_mask_i
# Expand the mask
expanded_mask_i = zero_one_mask_i.unsqueeze(0).expand(1, 1, tgt_len, src_len)
# Store the expanded mask
expanded_masks[i] = expanded_mask_i
inverted_mask = 1.0 - expanded_masks
# Expand the mask to the correct dimensions for the current batch index
expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min