remove padding self attention
This commit is contained in:
@@ -12,7 +12,7 @@ from torch.nn.attention.flex_attention import (
|
|||||||
_MaskType = Union[torch.Tensor, BlockMask]
|
_MaskType = Union[torch.Tensor, BlockMask]
|
||||||
|
|
||||||
|
|
||||||
def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
|
def create_block_causal_mask(seq_lens: list[torch.Tensor], max_seq_len: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
||||||
Construct a 2D block causal mask for each pack in the batch. For example, if
|
Construct a 2D block causal mask for each pack in the batch. For example, if
|
||||||
@@ -43,10 +43,14 @@ def create_block_causal_mask(seq_lens: list[torch.Tensor]) -> torch.Tensor:
|
|||||||
torch.tril(
|
torch.tril(
|
||||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
||||||
)
|
)
|
||||||
for i, seq_len in enumerate(seq_lens[sample_idx])
|
for seq_len in enumerate(seq_lens[sample_idx])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
||||||
|
block_attn_masks.append(torch.zeros(residue_len, residue_len, dtype=torch.bool, device=seq_lens[0][0].device))
|
||||||
|
|
||||||
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||||
|
|
||||||
return torch.stack(batch_block_attn_masks)[:,None,:,:]
|
return torch.stack(batch_block_attn_masks)[:,None,:,:]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ def get_seqlens_from_pos_ids(position_ids):
|
|||||||
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
position_ids = position_ids.unsqueeze(0)
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
max_seq_len = position_ids.shape[1]
|
||||||
|
|
||||||
device = position_ids.device
|
device = position_ids.device
|
||||||
results = []
|
results = []
|
||||||
@@ -126,22 +127,10 @@ def get_seqlens_from_pos_ids(position_ids):
|
|||||||
)
|
)
|
||||||
# Calculate the sequence lengths
|
# Calculate the sequence lengths
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
if padding_length:
|
|
||||||
seq_lengths = torch.cat(
|
|
||||||
[
|
|
||||||
seq_lengths,
|
|
||||||
torch.tensor(
|
|
||||||
[len(row) - torch.sum(seq_lengths)],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(seq_lengths)
|
results.append(seq_lengths)
|
||||||
|
|
||||||
return results
|
return results , max_seq_len
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
|
|||||||
@@ -179,9 +179,9 @@ class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
out = super().__call__(out_features, return_tensors=return_tensors)
|
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
collated_seq_lens = get_seqlens_from_pos_ids(out["position_ids"])
|
collated_seq_lens, max_seq_len = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens)
|
||||||
out["attention_mask"] = create_block_causal_mask(collated_seq_lens)
|
out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user