diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py index e79e7afc3..d6de38b16 100644 --- a/src/axolotl/monkeypatch/llama_expand_mask.py +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -10,9 +10,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ - # Move the mask to the CPU - mask = mask.cpu() - bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len @@ -27,8 +24,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one binary_mask_i = torch.where( mask_i != 0, - torch.tensor(1).to(dtype).cpu(), - torch.tensor(0).to(dtype).cpu(), + torch.tensor(1).to(dtype), + torch.tensor(0).to(dtype), ) # Create a block-diagonal mask @@ -44,7 +41,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min - ).cpu() + ) def hijack_expand_mask():