Lint flash_attn.py

NanoCode012
2023-05-29 09:50:36 +09:00
parent 903ea3080d
commit 1c60c10e00


@@ -1,9 +1,10 @@
 """Flash attention monkey patch for llama model"""
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 import torch
 from torch import nn
 import transformers
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 from flash_attn.bert_padding import unpad_input, pad_input
-def forward(
+def forward( # pylint: disable=too-many-arguments
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
@@ -82,6 +83,8 @@ def forward(
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     else:
         nheads = qkv.shape[-2]
+        # pylint: disable=invalid-name
         x = rearrange(qkv, "b s three h d -> b s (three h d)")
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(
@@ -104,13 +107,13 @@ def forward(
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
     self, attention_mask, input_shape, inputs_embeds, past_key_values_length
-):
+): # pylint: disable=unused-argument
     # [bsz, seq_len]
     return attention_mask
 def replace_llama_attn_with_flash_attn():
-    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
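
Usage note: the patch is meant to be applied before the model is instantiated, so that LlamaAttention.forward and LlamaModel._prepare_decoder_attention_mask are already replaced when the weights are loaded. A minimal sketch, assuming the module path matches the file name in this commit and using a hypothetical checkpoint name and dtype:

# Minimal usage sketch; the import path, checkpoint name, and dtype are
# illustrative assumptions, not part of this commit.
import torch
from transformers import LlamaForCausalLM

from axolotl.flash_attn import replace_llama_attn_with_flash_attn  # assumed path

# Apply the monkey patch before constructing the model.
replace_llama_attn_with_flash_attn()

model = LlamaForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # hypothetical checkpoint
    torch_dtype=torch.float16,  # flash attention expects fp16/bf16 on CUDA
).to("cuda")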