Properly calculate the maximum sequence length (max_seq_len) from cu_seqlens
This commit is contained in:
@@ -43,7 +43,9 @@ def get_cu_seqlens(attn_mask):
|
|||||||
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||||
)
|
)
|
||||||
|
|
||||||
return cu_seqlens.to(dtype=torch.int32)
|
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||||
|
|
||||||
|
return cu_seqlens.to(dtype=torch.int32), max_seq_len
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -119,8 +121,7 @@ def forward(
|
|||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
else:
|
else:
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
max_s = q_len
|
cu_q_lens, max_s = get_cu_seqlens(key_padding_mask)
|
||||||
cu_q_lens = get_cu_seqlens(key_padding_mask)
|
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
|
|||||||
Reference in New Issue
Block a user