Lint flash_attn.py
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
"""Flash attention monkey patch for llama model"""
|
||||
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
@@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
|
||||
|
||||
def forward(
|
||||
def forward( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
@@ -82,6 +83,8 @@ def forward(
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(
|
||||
@@ -104,13 +107,13 @@ def forward(
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
return attention_mask
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn():
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
||||
|
||||
Reference in New Issue
Block a user