diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py
index 6df0b8e18..406dd15ad 100644
--- a/src/axolotl/flash_attn.py
+++ b/src/axolotl/flash_attn.py
@@ -25,6 +25,7 @@ def forward(
 
     attention_mask: [bsz, q_len]
     """
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index ee013bd30..c6bdafb89 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -35,6 +35,7 @@ def xformers_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
@@ -143,6 +144,7 @@ def sdp_attention_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
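
For context: `duplicate-code` is pylint's R0801 similar-lines check, which fires when near-identical blocks of code appear across modules, as the shared attention-forward preambles above do. Below is a minimal sketch of the suppression pattern the diff applies (hypothetical function names, not from this repo). Note the sketch is illustrative only: pylint's default `min-similarity-lines` threshold is 4, so a block this short would not actually trigger R0801, and older pylint releases ignored inline disables for this particular checker, so whether the pragma takes effect depends on the pylint version pinned by the project.

```python
# Sketch (hypothetical module, not part of this PR): two functions that share
# an identical preamble, which pylint's duplicate-code check (R0801) would
# report if the block exceeded the similarity threshold. The inline pragma
# mirrors what this diff adds to each patched forward().


def attention_forward_a(hidden_states):
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()  # preamble repeated across modules
    return bsz, q_len


def attention_forward_b(hidden_states):
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()  # same lines as above
    return bsz, q_len
```

A project-wide alternative would be tuning the checker itself (e.g. `min-similarity-lines` under pylint's `[SIMILARITIES]` section), but the per-function pragma used here keeps the check active everywhere except the intentionally duplicated monkeypatch preambles.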