From 0b01da071381b08475814d8ebf76cbf3e3c0f1c3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 3 Aug 2023 16:12:04 -0400 Subject: [PATCH] properly calculate max len --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index c4f1fe947..74f5d0f9e 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -43,7 +43,9 @@ def get_cu_seqlens(attn_mask): [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] ) - return cu_seqlens.to(dtype=torch.int32) + max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + return cu_seqlens.to(dtype=torch.int32), max_seq_len def forward( @@ -119,8 +121,7 @@ def forward( output = rearrange(output, "(b s) ... -> b s ...", b=bsz) else: qkv = rearrange(qkv, "b s ... -> (b s) ...") - max_s = q_len - cu_q_lens = get_cu_seqlens(key_padding_mask) + cu_q_lens, max_s = get_cu_seqlens(key_padding_mask) output = flash_attn_varlen_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True