diff --git a/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py index 5f0d4fcd1..a21245374 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py +++ b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py @@ -365,7 +365,7 @@ class LolcatsLinearAttention(nn.Module): ..., None ] # b, 1, k_len, 1 else: - lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 + lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1 k = k.masked_fill(~lin_attn_mask, 0) if past_key_value is not None: # Initialize states