remove padding self attention

This commit is contained in:
Sunny Liu
2025-02-01 22:47:10 -05:00
parent 48c3c47071
commit 3f4fd3c1eb
3 changed files with 11 additions and 18 deletions

View File

@@ -12,7 +12,7 @@ from torch.nn.attention.flex_attention import (
_MaskType = Union[torch.Tensor, BlockMask]
def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
def create_block_causal_mask(seq_lens: list[torch.Tensor], max_seq_len: int) -> torch.Tensor:
"""
Given a batch tensor of seq lens defining the lengths of samples in each pack,
Construct a 2D block causal mask for each pack in the batch. For example, if
@@ -43,10 +43,14 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
torch.tril(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
)
for i, seq_len in enumerate(seq_lens[sample_idx])
for seq_len in enumerate(seq_lens[sample_idx])
]
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
block_attn_masks.append(torch.zeros(residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device))
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
return torch.stack(batch_block_attn_masks)[:,None,:,:]

View File

@@ -99,6 +99,7 @@ def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]
device = position_ids.device
results = []
@@ -126,22 +127,10 @@ def get_seqlens_from_pos_ids(position_ids):
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)
return results
return results , max_seq_len
def get_cu_seqlens_from_pos_ids(position_ids):

View File

@@ -179,9 +179,9 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
out_features[i][feature] = np.concatenate(arrays)
out = super().__call__(out_features, return_tensors=return_tensors)
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
out["attention_mask"] = create_block_causal_mask(collated_seq_lens)
out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
return out