Handle the xformers patch for inference as well

This commit is contained in:
Wing Lian
2025-05-05 03:22:02 -04:00
parent 5b2bd75aba
commit 82453bab7e

View File

@@ -65,6 +65,7 @@ def xformers_attention_forward(
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# pylint: disable=duplicate-code
if module.config.num_attention_heads != module.config.num_key_value_heads:
key = key.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
@@ -75,20 +76,12 @@ def xformers_attention_forward(
dim=2,
)
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
else:
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
elif attention_mask is not None:
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
@@ -111,6 +104,17 @@ def xformers_attention_forward(
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# pylint: disable=duplicate-code
if module.config.num_attention_heads != module.config.num_key_value_heads:
key = key.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
value = value.repeat_interleave(
module.config.num_attention_heads // module.config.num_key_value_heads,
dim=2,
)
attn_output = xformers_attention(
query,
key,