From 1c60c10e0076de9ff6e9c115a6919281ba3b3ecd Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:50:36 +0900 Subject: [PATCH] Lint flash_attn.py --- src/axolotl/flash_attn.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py index c1ceec788..d532e15a8 100644 --- a/src/axolotl/flash_attn.py +++ b/src/axolotl/flash_attn.py @@ -1,9 +1,10 @@ +"""Flash attention monkey patch for llama model""" + # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch -from torch import nn import transformers from transformers.models.llama.modeling_llama import apply_rotary_pos_emb @@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from flash_attn.bert_padding import unpad_input, pad_input -def forward( +def forward( # pylint: disable=too-many-arguments self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -82,6 +83,8 @@ def forward( output = rearrange(output, "(b s) ... -> b s ...", b=bsz) else: nheads = qkv.shape[-2] + + # pylint: disable=invalid-name x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad = rearrange( @@ -104,13 +107,13 @@ def forward( # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length -): +): # pylint: disable=unused-argument # [bsz, seq_len] return attention_mask def replace_llama_attn_with_flash_attn(): - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask ) transformers.models.llama.modeling_llama.LlamaAttention.forward = forward