diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py
index d6de38b16..3bea39531 100644
--- a/src/axolotl/monkeypatch/llama_expand_mask.py
+++ b/src/axolotl/monkeypatch/llama_expand_mask.py
@@ -9,35 +9,29 @@ import torch
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    This expansion handles packed sequences so that sequences share the same attention mask integer value
+    when they attend to each other within that sequence. This should result in a block diagonal mask
     """
     bsz, src_len = mask.size()
     tgt_len = tgt_len if tgt_len is not None else src_len
 
-    # Initialize a tensor to hold the expanded masks
-    expanded_masks = torch.zeros(bsz, 1, tgt_len, src_len).to(dtype)
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
 
-    # For each sequence in the batch
-    for i in range(bsz):
-        # Get the mask for this sequence
-        mask_i = mask[i].unsqueeze(0)
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1).to(dtype),
+        torch.tensor(0).to(dtype),
+    )
 
-        # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-        binary_mask_i = torch.where(
-            mask_i != 0,
-            torch.tensor(1).to(dtype),
-            torch.tensor(0).to(dtype),
-        )
+    # Create a block-diagonal mask.
+    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
+    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
 
-        # Create a block-diagonal mask
-        zero_one_mask_i = torch.eq(mask_i, mask_i.t()).int() * binary_mask_i
-
-        # Expand the mask
-        expanded_mask_i = zero_one_mask_i.unsqueeze(0).expand(1, 1, tgt_len, src_len)
-
-        # Store the expanded mask
-        expanded_masks[i] = expanded_mask_i
-
-    inverted_mask = 1.0 - expanded_masks
+    # Expand the mask to the correct dimensions for the current batch index
+    expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len)
+    inverted_mask = 1.0 - expanded_mask
 
     return inverted_mask.masked_fill(
         inverted_mask.to(torch.bool), torch.finfo(dtype).min
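
For review purposes, here is a minimal standalone sketch of the new vectorized path (not part of the patch): it copies the patched function body verbatim, adds the imports needed to run it on its own, and probes it with a hypothetical packed `attention_mask`. The toy values `[[1, 1, 2, 2, 0]]` and the demo at the bottom are illustrative assumptions, not taken from the repo.

```python
from typing import Optional

import torch


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Patched expansion: packed sequences sharing a mask integer attend to each other."""
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    # [bsz, src_len] -> [bsz, 1, 1, src_len] -> [bsz, 1, tgt_len, src_len]
    mask = mask.unsqueeze(1).unsqueeze(2)
    mask = mask.expand(bsz, 1, tgt_len, src_len)

    # Zeros (padding) stay zero; every non-zero sequence label becomes one
    binary_mask = torch.where(
        mask != 0,
        torch.tensor(1).to(dtype),
        torch.tensor(0).to(dtype),
    )

    # eq(mask, mask^T) is 1 where two positions carry the same sequence label;
    # multiplying by binary_mask excludes padding columns from every block
    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask

    expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len)
    inverted_mask = 1.0 - expanded_mask

    # Allowed positions get 0.0 additive bias; blocked positions get dtype-min
    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )


# Hypothetical packed batch: two sequences (labeled 1 and 2) packed into one
# row, with a trailing 0 for padding.
attention_mask = torch.tensor([[1, 1, 2, 2, 0]])
bias = _expand_mask(attention_mask, torch.float32)

# 1 marks positions that may attend to each other; the result is block diagonal,
# and the padded position is fully masked out.
print((bias == 0).int().squeeze())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
```

The key change is that the per-sample Python loop over `bsz` is replaced by a single broadcasted `torch.eq` against the transposed mask, which builds the same block-diagonal bias for the whole batch in one batched operation.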