stuff
This commit is contained in:
@@ -12,7 +12,9 @@ from torch.nn.attention.flex_attention import (
|
|||||||
_MaskType = Union[torch.Tensor, BlockMask]
|
_MaskType = Union[torch.Tensor, BlockMask]
|
||||||
|
|
||||||
|
|
||||||
def create_block_causal_mask(seq_lens: list[torch.Tensor], max_seq_len: int) -> torch.Tensor:
|
def create_block_causal_mask(
|
||||||
|
seq_lens: list[torch.Tensor], max_seq_len: int
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
||||||
Construct a 2D block causal mask for each pack in the batch. For example, if
|
Construct a 2D block causal mask for each pack in the batch. For example, if
|
||||||
@@ -43,15 +45,19 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor], max_seq_len: int) ->
|
|||||||
torch.tril(
|
torch.tril(
|
||||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
||||||
)
|
)
|
||||||
for seq_len in enumerate(seq_lens[sample_idx])
|
for seq_len in seq_lens[sample_idx]
|
||||||
]
|
]
|
||||||
|
|
||||||
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
||||||
block_attn_masks.append(torch.zeros(residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device))
|
block_attn_masks.append(
|
||||||
|
torch.zeros(
|
||||||
|
residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||||
|
|
||||||
return torch.stack(batch_block_attn_masks)[:,None,:,:]
|
return torch.stack(batch_block_attn_masks)[:, None, :, :]
|
||||||
|
|
||||||
|
|
||||||
def _get_document_ids_from_seq_lens(
|
def _get_document_ids_from_seq_lens(
|
||||||
|
|||||||
Reference in New Issue
Block a user