handle xformers patch for inference too
This commit is contained in:
@@ -65,6 +65,7 @@ def xformers_attention_forward(
|
|||||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||||
value = value.reshape(-1, value.size(-2), value.size(-1))
|
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
if module.config.num_attention_heads != module.config.num_key_value_heads:
|
if module.config.num_attention_heads != module.config.num_key_value_heads:
|
||||||
key = key.repeat_interleave(
|
key = key.repeat_interleave(
|
||||||
module.config.num_attention_heads // module.config.num_key_value_heads,
|
module.config.num_attention_heads // module.config.num_key_value_heads,
|
||||||
@@ -75,20 +76,12 @@ def xformers_attention_forward(
|
|||||||
dim=2,
|
dim=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = xformers_attention(
|
attn_output = xformers_attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
attn_bias=attn_bias,
|
attn_bias=attn_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
attn_output = xformers_attention(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
attn_bias=attn_bias,
|
|
||||||
)
|
|
||||||
elif attention_mask is not None:
|
elif attention_mask is not None:
|
||||||
batch_size = query.shape[0]
|
batch_size = query.shape[0]
|
||||||
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
||||||
@@ -111,6 +104,17 @@ def xformers_attention_forward(
|
|||||||
)
|
)
|
||||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
else:
|
else:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
if module.config.num_attention_heads != module.config.num_key_value_heads:
|
||||||
|
key = key.repeat_interleave(
|
||||||
|
module.config.num_attention_heads // module.config.num_key_value_heads,
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
value = value.repeat_interleave(
|
||||||
|
module.config.num_attention_heads // module.config.num_key_value_heads,
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
attn_output = xformers_attention(
|
attn_output = xformers_attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
|
|||||||
Reference in New Issue
Block a user