"""
|
|
Shared utils for the monkeypatches
|
|
"""
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from transformers.modeling_attn_mask_utils import (
|
|
_prepare_4d_causal_attention_mask,
|
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
)
|
|
from transformers.utils import is_torch_bf16_gpu_available
|
|
|
|
|
|
@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the length of every packed sequence in the batch, in order."""
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    # one counter per (row, sequence id); allocate on the same device as the mask
    counts = torch.zeros(
        (batch_size, max_num), dtype=torch.int32, device=attention_mask.device
    )

    for i in range(1, max_num + 1):
        mask = attention_mask == i
        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)

    result = counts.flatten()
    nonzero_indices = torch.nonzero(result).squeeze(-1)
    return result[nonzero_indices]

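
# Example (illustrative): a packed mask labels each sequence with its own integer and
# uses 0 for padding, e.g. two sequences of lengths 2 and 3 plus one pad token:
#     get_max_seqlen_in_batch(torch.tensor([[1, 1, 2, 2, 2, 0]]))
#     # -> tensor([2, 3], dtype=torch.int32)
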
@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
    """Compute the unpadding metadata flash attention expects from a packed mask."""
    device = attention_mask.device
    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
    # flat indices of every non-padding token
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # cumulative sequence lengths, prefixed with 0
    cu_seqlens = (
        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        .to(device=device)
        .detach()
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

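
# Example (illustrative): with the same packed mask as above,
#     indices, cu_seqlens, max_seqlen = get_unpad_data(torch.tensor([[1, 1, 2, 2, 2, 0]]))
#     # indices    -> tensor([0, 1, 2, 3, 4])   (positions of non-padding tokens)
#     # cu_seqlens -> tensor([0, 2, 5], dtype=torch.int32)
#     # max_seqlen -> 3
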
def get_cu_seqlens(attn_mask):
    """generate cumulative sequence lengths for flash attention from an attention mask"""
    if len(attn_mask.shape) == 1:
        attn_mask = attn_mask.unsqueeze(0)

    device = attn_mask.device
    results = []
    max_seq_lens = []

    for row in attn_mask:
        # Exclude zeros to avoid adding their positions to the mask
        t_non_zeros = row[row != 0]
        # Find where the sequence number changes (including the first position)
        seq_change = torch.cat(
            [
                torch.tensor([1], dtype=torch.int32, device=device),
                t_non_zeros[1:] != t_non_zeros[:-1],
            ]
        )
        # Get the indices where the sequence changes
        change_indices = torch.cat(
            [
                (seq_change == 1).nonzero(as_tuple=True)[0],
                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
            ]
        )
        # Calculate the sequence lengths
        seq_lengths = change_indices[1:] - change_indices[:-1]
        # Calculate the length of the final sequence or padding
        final_seq_length = len(row) - change_indices[-1]
        # Append the length of the final sequence or padding to seq_lengths
        if final_seq_length.item():
            seq_lengths = torch.cat(
                [
                    seq_lengths,
                    torch.tensor(
                        [final_seq_length.item()], dtype=torch.int32, device=device
                    ),
                ]
            )
        # Calculate the cumulative sequence lengths
        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)

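
# Example (illustrative): the same packed mask yields one row of cumulative offsets,
# with the trailing padding counted as a final "sequence":
#     cu_seqlens, max_seq_lens = get_cu_seqlens(torch.tensor([[1, 1, 2, 2, 2, 0]]))
#     # cu_seqlens   -> tensor([[0, 2, 5, 6]], dtype=torch.int32)
#     # max_seq_lens -> tensor([3])
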
def get_cu_seqlens_from_pos_ids(position_ids):
    """generate cumulative sequence lengths for flash attention from position ids"""
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)

    device = position_ids.device
    results = []
    max_seq_lens = []

    for row in position_ids:
        # Count the number of consecutive zeros from the right side
        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()

        # Adjust the row to exclude padding
        adjusted_row = row[:-padding_length] if padding_length else row.clone()

        # Find where the position resets to 0 (indicating a new sequence)
        seq_starts = torch.cat(
            [
                torch.tensor([True], dtype=torch.bool, device=device),
                adjusted_row[1:] == 0,
            ]
        )
        # Get the indices where the sequence starts
        start_indices = torch.cat(
            [
                torch.nonzero(seq_starts).unbind(dim=1)[0],
                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
            ]
        )
        # Calculate the sequence lengths
        seq_lengths = start_indices[1:] - start_indices[:-1]
        # Calculate the cumulative sequence lengths
        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )
        # Append the padding length to the cumulative sequence lengths
        if padding_length:
            cu_seqlens = torch.cat(
                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
            )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    # Find the maximum value across all tensors
    max_value = max(t.max() for t in results)

    # Find the length of the longest tensor
    max_length = max(t.size(0) for t in results)

    # Pad each tensor to the same length and collect them in a list
    padded_results = [
        F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
    ]

    return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)

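
# Example (illustrative): position ids restart at 0 for each packed sequence and the
# trailing zeros are padding, e.g. sequences of lengths 2 and 3 plus one pad token:
#     cu_seqlens, max_seq_lens = get_cu_seqlens_from_pos_ids(
#         torch.tensor([[0, 1, 0, 1, 2, 0]])
#     )
#     # cu_seqlens   -> tensor([[0, 2, 5, 6]], dtype=torch.int32)
#     # max_seq_lens -> tensor([3])
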
def set_module_name(model, name, value):
    """replace the submodule at the dotted path `name` in `model` with `value`"""
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name

    setattr(parent, child_name, value)

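
# Example (illustrative; the module path is hypothetical and depends on the model):
#     set_module_name(model, "model.layers.0.self_attn", patched_attention_module)
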
def mask_2d_to_4d(
    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    This expansion handles packed sequences: tokens that share the same integer value in
    the mask belong to the same sequence and may only attend to each other.
    The result is also made lower triangular to prevent attending to future positions.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    mask = mask.unsqueeze(1).unsqueeze(2)
    mask = mask.expand(bsz, 1, tgt_len, src_len)

    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    binary_mask = torch.where(
        mask != 0,
        torch.tensor(1, device=mask.device).to(dtype),
        torch.tensor(0, device=mask.device).to(dtype),
    )

    # Create a block-diagonal mask.
    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask

    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
        mask.device
    )

    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
    masked_zero_one_mask = zero_one_mask * lower_triangular_ones

    return masked_zero_one_mask

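
# Example (illustrative): two packed sequences of length 2 each produce a causal,
# block-diagonal 4d mask (1 = may attend, 0 = masked):
#     mask_2d_to_4d(torch.tensor([[1, 1, 2, 2]]), dtype=torch.float32)
#     # -> tensor([[[[1., 0., 0., 0.],
#     #              [1., 1., 0., 0.],
#     #              [0., 0., 1., 0.],
#     #              [0., 0., 1., 1.]]]])
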
def patched_prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    # expand the packed 2d mask before handing it to the stock transformers helper
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )


def patched_prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    # same as above, but for the SDPA code path
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask_for_sdpa(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )
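

# Sketch (assumption: the exact import sites to patch depend on the transformers
# version in use). These patched helpers are meant to be monkeypatched over the stock
# mask builders, e.g.:
#
#     import transformers.modeling_attn_mask_utils as mask_utils
#
#     mask_utils._prepare_4d_causal_attention_mask = (
#         patched_prepare_4d_causal_attention_mask
#     )
#     mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
#         patched_prepare_4d_causal_attention_mask_for_sdpa
#     )
#
# Modeling files that imported these names directly may need the same assignment at
# their own module level for the patch to take effect.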