From ebd406af1d29dc7ae41c3ab1fa25e55c52e4977a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 6 Feb 2025 15:25:33 +0700 Subject: [PATCH] fix: lin_attn_mask has wrong dtype for masked_fill --- .../integrations/lolcats/linear_llama/linear_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py index 5f0d4fcd1..a21245374 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py +++ b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py @@ -365,7 +365,7 @@ class LolcatsLinearAttention(nn.Module): ..., None ] # b, 1, k_len, 1 else: - lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 + lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1 k = k.masked_fill(~lin_attn_mask, 0) if past_key_value is not None: # Initialize states