fix for flash attn w mistral w/o sample packing (#648)

This commit is contained in:
Wing Lian
2023-09-28 10:57:37 -04:00
committed by GitHub
parent b88f51512a
commit b2edaaeff6

View File

@@ -2,13 +2,17 @@
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging import logging
import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import transformers import transformers
from einops import rearrange from einops import rearrange
from torch import nn from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
@@ -17,16 +21,6 @@ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, r
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_varlen_qkvpacked_func,
)
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
LOG = logging.getLogger("axolotl.monkeypatch.mistral") LOG = logging.getLogger("axolotl.monkeypatch.mistral")
@@ -108,6 +102,15 @@ def flashattn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
@@ -120,46 +123,84 @@ def flashattn_forward(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
attn_output = output elif query_states.shape == key_states.shape:
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): query_states = query_states.transpose(1, 2)
raise ValueError( key_states = key_states.transpose(1, 2)
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" value_states = value_states.transpose(1, 2)
f" {attn_output.size()}" qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
) query_states,
attn_output = rearrange(attn_output, "b s h d -> b s (h d)") key_states,
attn_weights = None value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else: else:
attn_weights = torch.matmul( query_states = query_states.transpose(1, 2)
query_states, key_states.transpose(2, 3) key_states = key_states.transpose(1, 2)
) / math.sqrt(self.head_dim) value_states = value_states.transpose(1, 2)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): if attention_mask is None or attention_mask.all().item():
raise ValueError( output = flash_attn_kvpacked_func(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" query_states,
f" {attn_weights.size()}" torch.stack([key_states, value_states], 2),
causal=is_causal,
) )
else:
if attention_mask is not None: ( # pylint: disable=unbalanced-tuple-unpacking
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): q_unpad,
raise ValueError( kv_unpad,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" cu_seqlens_q,
) cu_seqlens_k,
max_seqlen_q,
attn_weights = attn_weights + attention_mask max_seqlen_k,
_,
# upcast attention to fp32 _,
attn_weights = nn.functional.softmax( output_pad_fn,
attn_weights, dim=-1, dtype=torch.float32 ) = generate_qkv(
).to(query_states.dtype) query_states,
attn_output = torch.matmul(attn_weights, value_states) key_states,
value_states,
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): kvpacked=True,
raise ValueError( key_padding_mask=attention_mask,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" query_padding_mask=attention_mask[:, -query_states.size(1) :]
f" {attn_output.size()}" if attention_mask is not None
else None,
) )
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = output
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
@@ -169,6 +210,105 @@ def flashattn_forward(
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name
    """
    Flatten q/k/v over the batch and sequence axes and build the cumulative
    sequence-length vectors consumed by the flash-attn varlen functions.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool; when given, q is run
            through ``unpad_input`` with this mask, otherwise q is reshaped
            to ``(batch_size * seqlen_q, nheads, d)`` as-is
        key_padding_mask: (batch_size, seqlen), bool; same role for k and v
        kvpacked: stack k and v into a single tensor in the result
        qkvpacked: stack q, k and v into a single tensor in the result;
            mutually exclusive with ``kvpacked``

    Returns:
        A tuple whose layout depends on the packing flags:

        * ``qkvpacked``: ``(qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv,
          output_pad_fn)``
        * ``kvpacked``: ``(q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k,
          max_seqlen_q, max_seqlen_k, q, kv, output_pad_fn)``
        * neither: ``(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
          max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn)``

        ``output_pad_fn`` maps a flattened attention output back to
        ``(batch_size, seqlen_q, ...)``, via ``pad_input`` when a query mask
        was given and via a plain reshape otherwise.
    """
    assert not kvpacked or not qkvpacked
    batch_size, seqlen_q, nheads, head_dim = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    kv_shape = (batch_size, seqlen_k, nheads_k, head_dim)
    assert k.shape == kv_shape
    assert v.shape == kv_shape

    def _uniform_cu_seqlens(seqlen, device):
        # boundaries 0, seqlen, 2*seqlen, ... for a batch without padding
        return torch.arange(
            0,
            (batch_size + 1) * seqlen,
            step=seqlen,
            dtype=torch.int32,
            device=device,
        )

    # --- query side ---
    if query_padding_mask is None:
        q_flat = rearrange(q, "b s h d -> (b s) h d")
        cu_q = _uniform_cu_seqlens(seqlen_q, q_flat.device)
        max_q = seqlen_q

        def restore_padding(output_unpad):
            return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)

    else:
        q_flat, q_indices, cu_q, max_q = unpad_input(q, query_padding_mask)

        def restore_padding(output_unpad):
            return pad_input(output_unpad, q_indices, batch_size, seqlen_q)

    # --- key / value side ---
    if key_padding_mask is None:
        k_flat = rearrange(k, "b s h d -> (b s) h d")
        v_flat = rearrange(v, "b s h d -> (b s) h d")
        cu_k = _uniform_cu_seqlens(seqlen_k, k_flat.device)
        max_k = seqlen_k
    else:
        k_flat, _, cu_k, max_k = unpad_input(k, key_padding_mask)
        v_flat, _, _, _ = unpad_input(v, key_padding_mask)

    # --- assemble the result in the requested packing ---
    if qkvpacked:
        # q, k and v share one tensor, so the head counts have to agree
        assert nheads == nheads_k
        packed_flat = torch.stack([q_flat, k_flat, v_flat], dim=1)
        packed = torch.stack([q, k, v], dim=2)
        return (packed_flat, cu_q, max_q, packed, restore_padding)

    if kvpacked:
        kv_flat = torch.stack([k_flat, v_flat], dim=1)
        kv_padded = torch.stack([k, v], dim=2)
        return (
            q_flat,
            kv_flat,
            cu_q,
            cu_k,
            max_q,
            max_k,
            q,
            kv_padded,
            restore_padding,
        )

    return (
        q_flat,
        k_flat,
        v_flat,
        cu_q,
        cu_k,
        max_q,
        max_k,
        q,
        k,
        v,
        restore_padding,
    )
def mistral_model_forward( def mistral_model_forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,