don't move masks to cpu
This commit is contained in:
@@ -10,9 +10,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
# Move the mask to the CPU
|
||||
mask = mask.cpu()
|
||||
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
@@ -27,8 +24,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask_i = torch.where(
|
||||
mask_i != 0,
|
||||
torch.tensor(1).to(dtype).cpu(),
|
||||
torch.tensor(0).to(dtype).cpu(),
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask
|
||||
@@ -44,7 +41,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
).cpu()
|
||||
)
|
||||
|
||||
|
||||
def hijack_expand_mask():
|
||||
|
||||
Reference in New Issue
Block a user