Don't move masks to the CPU
This commit is contained in:
@@ -10,9 +10,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
"""
|
"""
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
"""
|
"""
|
||||||
# Move the mask to the CPU
|
|
||||||
mask = mask.cpu()
|
|
||||||
|
|
||||||
bsz, src_len = mask.size()
|
bsz, src_len = mask.size()
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
@@ -27,8 +24,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
binary_mask_i = torch.where(
|
binary_mask_i = torch.where(
|
||||||
mask_i != 0,
|
mask_i != 0,
|
||||||
torch.tensor(1).to(dtype).cpu(),
|
torch.tensor(1).to(dtype),
|
||||||
torch.tensor(0).to(dtype).cpu(),
|
torch.tensor(0).to(dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a block-diagonal mask
|
# Create a block-diagonal mask
|
||||||
@@ -44,7 +41,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
|
|
||||||
return inverted_mask.masked_fill(
|
return inverted_mask.masked_fill(
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
).cpu()
|
)
|
||||||
|
|
||||||
|
|
||||||
def hijack_expand_mask():
|
def hijack_expand_mask():
|
||||||
|
|||||||
Reference in New Issue
Block a user