handle xformers patch for inference too

This commit is contained in:
Wing Lian
2025-05-05 03:22:02 -04:00
parent 5b2bd75aba
commit 82453bab7e

View File

@@ -65,6 +65,7 @@ def xformers_attention_forward(
key = key.reshape(-1, key.size(-2), key.size(-1)) key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1)) value = value.reshape(-1, value.size(-2), value.size(-1))
# pylint: disable=duplicate-code
if module.config.num_attention_heads != module.config.num_key_value_heads: if module.config.num_attention_heads != module.config.num_key_value_heads:
key = key.repeat_interleave( key = key.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads, module.config.num_attention_heads // module.config.num_key_value_heads,
@@ -75,20 +76,12 @@ def xformers_attention_forward(
dim=2, dim=2,
) )
attn_output = xformers_attention( attn_output = xformers_attention(
query, query,
key, key,
value, value,
attn_bias=attn_bias, attn_bias=attn_bias,
) )
else:
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
elif attention_mask is not None: elif attention_mask is not None:
batch_size = query.shape[0] batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, _ = _upad_input( query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
@@ -111,6 +104,17 @@ def xformers_attention_forward(
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else: else:
# pylint: disable=duplicate-code
if module.config.num_attention_heads != module.config.num_key_value_heads:
key = key.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
value = value.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
attn_output = xformers_attention( attn_output = xformers_attention(
query, query,
key, key,