fix for flash attn w mistral w/o sammple packing (#648)
This commit is contained in:
@@ -2,13 +2,17 @@
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
|
flash_attn_varlen_kvpacked_func,
|
||||||
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
@@ -17,16 +21,6 @@ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, r
|
|||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
|
||||||
flash_attn_varlen_qkvpacked_func,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
from flash_attn.flash_attn_interface import (
|
|
||||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||||
|
|
||||||
|
|
||||||
@@ -108,6 +102,15 @@ def flashattn_forward(
|
|||||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
@@ -120,46 +123,84 @@ def flashattn_forward(
|
|||||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
attn_output = output
|
elif query_states.shape == key_states.shape:
|
||||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
query_states = query_states.transpose(1, 2)
|
||||||
raise ValueError(
|
key_states = key_states.transpose(1, 2)
|
||||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
value_states = value_states.transpose(1, 2)
|
||||||
f" {attn_output.size()}"
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
)
|
query_states,
|
||||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
key_states,
|
||||||
attn_weights = None
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.matmul(
|
query_states = query_states.transpose(1, 2)
|
||||||
query_states, key_states.transpose(2, 3)
|
key_states = key_states.transpose(1, 2)
|
||||||
) / math.sqrt(self.head_dim)
|
value_states = value_states.transpose(1, 2)
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
if attention_mask is None or attention_mask.all().item():
|
||||||
raise ValueError(
|
output = flash_attn_kvpacked_func(
|
||||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
query_states,
|
||||||
f" {attn_weights.size()}"
|
torch.stack([key_states, value_states], 2),
|
||||||
|
causal=is_causal,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
if attention_mask is not None:
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
q_unpad,
|
||||||
raise ValueError(
|
kv_unpad,
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
cu_seqlens_q,
|
||||||
)
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
attn_weights = attn_weights + attention_mask
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
# upcast attention to fp32
|
_,
|
||||||
attn_weights = nn.functional.softmax(
|
output_pad_fn,
|
||||||
attn_weights, dim=-1, dtype=torch.float32
|
) = generate_qkv(
|
||||||
).to(query_states.dtype)
|
query_states,
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
key_states,
|
||||||
|
value_states,
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
kvpacked=True,
|
||||||
raise ValueError(
|
key_padding_mask=attention_mask,
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
f" {attn_output.size()}"
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = output
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
@@ -169,6 +210,105 @@ def flashattn_forward(
|
|||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mistral_model_forward(
|
def mistral_model_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user