fix: lin_attn_mask had the wrong dtype; cast attention_mask to bool before inverting it for masked_fill
This commit is contained in:
@@ -365,7 +365,7 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
..., None
|
..., None
|
||||||
] # b, 1, k_len, 1
|
] # b, 1, k_len, 1
|
||||||
else:
|
else:
|
||||||
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
|
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
|
||||||
k = k.masked_fill(~lin_attn_mask, 0)
|
k = k.masked_fill(~lin_attn_mask, 0)
|
||||||
|
|
||||||
if past_key_value is not None: # Initialize states
|
if past_key_value is not None: # Initialize states
|
||||||
|
|||||||
Reference in New Issue
Block a user