optimized expand mask fn
This commit is contained in:
@@ -9,35 +9,29 @@ import torch
|
|||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||||
"""
|
"""
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
|
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||||
|
when they attend to each other within that sequence. This should result in a block diagonal mask
|
||||||
"""
|
"""
|
||||||
bsz, src_len = mask.size()
|
bsz, src_len = mask.size()
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
# Initialize a tensor to hold the expanded masks
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
expanded_masks = torch.zeros(bsz, 1, tgt_len, src_len).to(dtype)
|
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||||
|
|
||||||
# For each sequence in the batch
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
for i in range(bsz):
|
binary_mask = torch.where(
|
||||||
# Get the mask for this sequence
|
mask != 0,
|
||||||
mask_i = mask[i].unsqueeze(0)
|
torch.tensor(1).to(dtype),
|
||||||
|
torch.tensor(0).to(dtype),
|
||||||
|
)
|
||||||
|
|
||||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
# Create a block-diagonal mask.
|
||||||
binary_mask_i = torch.where(
|
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||||
mask_i != 0,
|
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||||
torch.tensor(1).to(dtype),
|
|
||||||
torch.tensor(0).to(dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a block-diagonal mask
|
# Expand the mask to the correct dimensions for the current batch index
|
||||||
zero_one_mask_i = torch.eq(mask_i, mask_i.t()).int() * binary_mask_i
|
expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len)
|
||||||
|
inverted_mask = 1.0 - expanded_mask
|
||||||
# Expand the mask
|
|
||||||
expanded_mask_i = zero_one_mask_i.unsqueeze(0).expand(1, 1, tgt_len, src_len)
|
|
||||||
|
|
||||||
# Store the expanded mask
|
|
||||||
expanded_masks[i] = expanded_mask_i
|
|
||||||
|
|
||||||
inverted_mask = 1.0 - expanded_masks
|
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
return inverted_mask.masked_fill(
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
|
|||||||
Reference in New Issue
Block a user