From c56818b11978bafac83ed5e4949cd4d9aa0f6326 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 00:06:47 -0400 Subject: [PATCH] don't worry about dupes --- src/axolotl/flash_attn.py | 1 + src/axolotl/monkeypatch/llama_attn_hijack_xformers.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py index 6df0b8e18..406dd15ad 100644 --- a/src/axolotl/flash_attn.py +++ b/src/axolotl/flash_attn.py @@ -25,6 +25,7 @@ def forward( attention_mask: [bsz, q_len] """ + # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = ( diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index ee013bd30..c6bdafb89 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -35,6 +35,7 @@ def xformers_forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = ( @@ -143,6 +144,7 @@ def sdp_attention_forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() query_states = (