don't move masks to cpu

This commit is contained in:
Wing Lian
2023-07-17 11:08:43 -04:00
parent ef9bf7ad73
commit ffd96839cf

View File

@@ -10,9 +10,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
# Move the mask to the CPU
mask = mask.cpu()
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
@@ -27,8 +24,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask_i = torch.where(
mask_i != 0,
torch.tensor(1).to(dtype).cpu(),
torch.tensor(0).to(dtype).cpu(),
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask
@@ -44,7 +41,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
).cpu()
)
def hijack_expand_mask():