diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py index 97db4d414..a157906d5 100644 --- a/src/axolotl/monkeypatch/flex_attn.py +++ b/src/axolotl/monkeypatch/flex_attn.py @@ -10,6 +10,45 @@ from torch.nn.attention.flex_attention import ( _MaskType = Union[torch.Tensor, BlockMask] + +def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor: + """ + Given a batch tensor of seq lens defining the lengths of samples in each pack, + Construct a 2D block causal mask for each pack in the batch. For example, if + a single sample's seq_lens is [3, 2, 1], the mask would be:: + + mask = [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ] + + Args: + seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, + shape (batch_size, n), where n is the max number of sequences in a pack and can vary + across packs. + + + Returns: + Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len). + """ + batch_block_attn_masks = [] + batch_size = len(seq_lens) + for sample_idx in range(batch_size): + block_attn_masks = [ + torch.tril( + torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device) + ) + for i, seq_len in enumerate(seq_lens[sample_idx]) + ] + + batch_block_attn_masks.append(torch.block_diag(*block_attn_masks)) + return torch.stack(batch_block_attn_masks) + + def _get_document_ids_from_seq_lens( seq_lens: list[torch.Tensor], ) -> torch.Tensor: diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 6508602cf..d449b38ad 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -11,7 +11,7 @@ from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy from axolotl.monkeypatch.utils import get_seqlens_from_pos_ids -from axolotl.monkeypatch.flex_attn import packed_block_causal_mask +from axolotl.monkeypatch.flex_attn import create_block_causal_mask, packed_block_causal_mask @dataclass @@ -177,7 +177,8 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): out = super().__call__(out_features, return_tensors=return_tensors) collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"]) - out["attention_mask"] = packed_block_causal_mask(collated_seq_lens) + # out["attention_mask"] = packed_block_causal_mask(collated_seq_lens) + out["attention_mask"] = create_block_causal_mask(collated_seq_lens) return out