handle xformers patch for inference too
@@ -65,6 +65,7 @@ def xformers_attention_forward(
         key = key.reshape(-1, key.size(-2), key.size(-1))
         value = value.reshape(-1, value.size(-2), value.size(-1))
 
+        # pylint: disable=duplicate-code
         if module.config.num_attention_heads != module.config.num_key_value_heads:
             key = key.repeat_interleave(
                 module.config.num_attention_heads // module.config.num_key_value_heads,
@@ -75,20 +76,12 @@ def xformers_attention_forward(
                 dim=2,
             )
 
-            attn_output = xformers_attention(
-                query,
-                key,
-                value,
-                attn_bias=attn_bias,
-            )
-
-        else:
-            attn_output = xformers_attention(
-                query,
-                key,
-                value,
-                attn_bias=attn_bias,
-            )
+        attn_output = xformers_attention(
+            query,
+            key,
+            value,
+            attn_bias=attn_bias,
+        )
     elif attention_mask is not None:
         batch_size = query.shape[0]
         query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
@@ -111,6 +104,17 @@ def xformers_attention_forward(
         )
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
     else:
+        # pylint: disable=duplicate-code
+        if module.config.num_attention_heads != module.config.num_key_value_heads:
+            key = key.repeat_interleave(
+                module.config.num_attention_heads // module.config.num_key_value_heads,
+                dim=2,
+            )
+            value = value.repeat_interleave(
+                module.config.num_attention_heads // module.config.num_key_value_heads,
+                dim=2,
+            )
+
         attn_output = xformers_attention(
             query,
             key,
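For context, the block this commit duplicates into the final else branch (the inference path) is grouped-query attention (GQA) head expansion: when num_attention_heads differs from num_key_value_heads, the key/value heads are repeated along the head axis so the xformers call sees matching head counts for query, key, and value. Below is a minimal standalone sketch of that expansion; the shapes and config values are hypothetical illustrations, not taken from the repo.

import torch

# Hypothetical GQA configuration; in the patch these come from module.config.
num_attention_heads = 32
num_key_value_heads = 8
batch, seq_len, head_dim = 2, 16, 64

# [batch, seq_len, heads, head_dim] layout, matching the tensors before the reshape above.
query = torch.randn(batch, seq_len, num_attention_heads, head_dim)
key = torch.randn(batch, seq_len, num_key_value_heads, head_dim)
value = torch.randn(batch, seq_len, num_key_value_heads, head_dim)

if num_attention_heads != num_key_value_heads:
    # Repeat each KV head so there is one per query head (dim=2 is the head axis),
    # mirroring the repeat_interleave calls in the patch.
    n_rep = num_attention_heads // num_key_value_heads
    key = key.repeat_interleave(n_rep, dim=2)
    value = value.repeat_interleave(n_rep, dim=2)

assert key.shape == value.shape == query.shape  # head counts now match for the attention call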