Lint flash_attn.py
This commit is contained in:
@@ -1,9 +1,10 @@
|
|||||||
|
"""Flash attention monkey patch for llama model"""
|
||||||
|
|
||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
@@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
|||||||
from flash_attn.bert_padding import unpad_input, pad_input
|
from flash_attn.bert_padding import unpad_input, pad_input
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward( # pylint: disable=too-many-arguments
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
@@ -82,6 +83,8 @@ def forward(
|
|||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
else:
|
else:
|
||||||
nheads = qkv.shape[-2]
|
nheads = qkv.shape[-2]
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||||
x_unpad = rearrange(
|
x_unpad = rearrange(
|
||||||
@@ -104,13 +107,13 @@ def forward(
|
|||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||||
):
|
): # pylint: disable=unused-argument
|
||||||
# [bsz, seq_len]
|
# [bsz, seq_len]
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn():
|
def replace_llama_attn_with_flash_attn():
|
||||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
_prepare_decoder_attention_mask
|
_prepare_decoder_attention_mask
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||||
|
|||||||
Reference in New Issue
Block a user