fix: lin_attn_mask was created with the wrong dtype — cast attention_mask to bool so masked_fill's `~` inversion works

This commit is contained in:
NanoCode012
2025-02-06 15:25:33 +07:00
parent caa49a9d7d
commit ebd406af1d

View File

@@ -365,7 +365,7 @@ class LolcatsLinearAttention(nn.Module):
..., None
] # b, 1, k_len, 1
else:
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
k = k.masked_fill(~lin_attn_mask, 0)
if past_key_value is not None: # Initialize states