diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index f727c74b8..be50dcdb2 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -245,7 +245,6 @@ def flashattn_forward_with_s2attn(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
     max_seqlen: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -374,7 +373,6 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -770,12 +768,6 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
-        padding_mask = None
-    else:
-        if 0 in attention_mask:
-            padding_mask = attention_mask
-        else:
-            padding_mask = None
 
     attention_mask = (
         self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
@@ -825,7 +817,6 @@ def llama_model_forward(
                 past_key_value,
                 output_attentions,
                 None,
-                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
@@ -837,7 +828,6 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
-                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
@@ -884,7 +874,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
@@ -917,7 +906,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
-            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
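
Note: with `padding_mask` dropped from every signature and call site, the flash-attention path relies only on the varlen metadata that remains: `cu_seqlens` (cumulative per-sequence lengths, flash-attn's varlen convention) and `max_seqlen`. Below is a minimal sketch of how such metadata relates to a 0/1 attention mask, in the spirit of the `_get_unpad_data` helper in transformers' Llama flash-attention code; the helper name here is illustrative and not part of this patch.

import torch

def cu_seqlens_from_mask(attention_mask: torch.Tensor):
    """Illustrative only: derive flash-attn varlen metadata from a 0/1 mask."""
    # Per-sequence token counts, e.g. [[1,1,1,0],[1,1,0,0]] -> [3, 2]
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Cumulative offsets with a leading zero: [3, 2] -> [0, 3, 5]
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    return cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(cu_seqlens_from_mask(mask))
# (tensor([0, 3, 5], dtype=torch.int32), 3)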