fix: lin_attn_mask in wrong dtype
This commit is contained in:
@@ -365,7 +365,7 @@ class LolcatsLinearAttention(nn.Module):
                        ..., None
                    ] # b, 1, k_len, 1
                else:
-                    lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
+                    lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
                k = k.masked_fill(~lin_attn_mask, 0)

            if past_key_value is not None: # Initialize states

Reference in New Issue
Block a user