diff --git a/src/axolotl/monkeypatch/flex_attn.py b/src/axolotl/monkeypatch/flex_attn.py
index f764e7c16..76b7d3bc8 100644
--- a/src/axolotl/monkeypatch/flex_attn.py
+++ b/src/axolotl/monkeypatch/flex_attn.py
@@ -12,7 +12,9 @@ from torch.nn.attention.flex_attention import (
 _MaskType = Union[torch.Tensor, BlockMask]
 
 
-def create_block_causal_mask(seq_lens: list[torch.Tensor], max_seq_len: int) -> torch.Tensor:
+def create_block_causal_mask(
+    seq_lens: list[torch.Tensor], max_seq_len: int
+) -> torch.Tensor:
     """
     Given a batch tensor of seq lens defining the lengths of samples in each pack,
     Construct a 2D block causal mask for each pack in the batch. For example, if
@@ -43,15 +45,19 @@
             torch.tril(
                 torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
             )
-            for seq_len in enumerate(seq_lens[sample_idx])
+            for seq_len in seq_lens[sample_idx]
         ]
-        
+
         residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
-        block_attn_masks.append(torch.zeros(residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device))
-        
+        block_attn_masks.append(
+            torch.zeros(
+                residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device
+            )
+        )
+
         batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
 
-    return torch.stack(batch_block_attn_masks)[:,None,:,:]
+    return torch.stack(batch_block_attn_masks)[:, None, :, :]
 
 
 def _get_document_ids_from_seq_lens(
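Why the comprehension change is a bug fix and not just a reformat: in the old code, `enumerate(seq_lens[sample_idx])` yields `(index, length)` tuples, so `seq_len` was never a length at all and `torch.ones(seq_len, seq_len, ...)` could not consume it. Below is a minimal, self-contained sketch of what the patched construction computes. The toy lengths and the `int()` casts are illustrative additions, not part of the patch:

```python
import torch

# Illustrative values: one pack containing samples of length 2 and 3,
# padded out to max_seq_len = 6.
seq_lens = [torch.tensor([2, 3])]
max_seq_len = 6

batch_block_attn_masks = []
for sample_idx in range(len(seq_lens)):
    # One lower-triangular (causal) block per packed sample.
    block_attn_masks = [
        torch.tril(torch.ones(int(seq_len), int(seq_len), dtype=torch.bool))
        for seq_len in seq_lens[sample_idx]  # the fix: values, not enumerate()
    ]
    # Padding positions attend to (and are attended by) nothing:
    # an all-False residue block fills the pack to max_seq_len.
    residue_len = max_seq_len - int(torch.sum(seq_lens[sample_idx]))
    block_attn_masks.append(torch.zeros(residue_len, residue_len, dtype=torch.bool))
    batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))

mask = torch.stack(batch_block_attn_masks)[:, None, :, :]
print(mask.shape)        # torch.Size([1, 1, 6, 6])
print(mask[0, 0].int())  # two causal blocks on the diagonal, then zeros
```

Each pack becomes a block-diagonal mask: one causal block per sample, so tokens attend only to earlier tokens within their own sample, never across sample boundaries or into padding.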
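The returned mask has shape `(batch, 1, seq, seq)`, so it broadcasts over attention heads. This module targets flex attention, but the dense boolean mask follows the usual `attn_mask` convention, so as a hedged illustration it can be exercised with `torch.nn.functional.scaled_dot_product_attention` (shapes are arbitrary; the pack is chosen with no padding residue, since an all-`False` mask row would produce NaN outputs):

```python
import torch
import torch.nn.functional as F

# Toy pack with no padding residue: samples of length 2 and 4 fill seq = 6.
blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in (2, 4)]
mask = torch.block_diag(*blocks)[None, None]  # (1, 1, 6, 6)

batch, heads, seq, dim = 1, 4, 6, 8
q, k, v = (torch.randn(batch, heads, seq, dim) for _ in range(3))

# Boolean attn_mask: True = "this query position may attend to this key".
# The (1, 1, 6, 6) mask broadcasts across the head dimension.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 4, 6, 8])
```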